Compute modules

Compute tasks for lab data processing and analysis.

This module provides compute task classes for running analyses on lab data. Tasks can be scheduled and run on dedicated containers using job schedulers.

Available compute tasks: - SpksCompute: Spike sorting using Kilosort/Phy via SPKS - DeeplabcutCompute: Animal pose estimation using DeepLabCut

Each compute task can be: - Scheduled to run on compute clusters via SLURM/PBS - Executed in isolated Singularity/Docker containers - Tracked and monitored through the database - Configured via user preferences

The compute tasks handle: - Input/output file management - Container and environment setup - Job scheduling and resource allocation - Progress tracking and error handling - Result storage and validation

BaseCompute

Source code in labdata/compute/utils.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
class BaseCompute():
    """Base class for compute tasks that run analyses on lab datasets.

    Subclasses (e.g. spike sorting, pose estimation) override ``name``,
    ``container``, ``find_datasets`` and ``_compute``. Tasks are tracked in
    the ``ComputeTask`` table and can run locally, on clusters, or inside
    singularity/apptainer containers.
    """
    name = None                  # task identifier stored in ComputeTask.task_name
    container = 'labdata-base'   # default container image for this task
    cuda = False                 # set True in subclasses that need a GPU
    ec2 = dict(small = dict(instance_type = 'g4dn.2xlarge'),   # 8 cpus, 32 GB mem, 200 GB nvme, 1 gpu
               large = dict(instance_type = 'g6.4xlarge',
                            availability_zone = 'us-west-2b')) # 16 cpus, 64 GB mem, 600 GB nvme, 1 gpu

    def __init__(self, job_id, project = None, allow_s3 = None, keep_intermediate = False):
        '''
        Executes a computation on a dataset, that can be remote or local.
        Uses a singularity/apptainer image if possible.

        Parameters
        ----------
        job_id : int or None
            Identifier of the ComputeTask row to run; validated against the DB.
        project : str, optional
            Project name passed to ``load_project_schema``.
        allow_s3 : bool, optional
            Whether missing files may be downloaded from S3. When None the
            'allow_s3_download' preference is used.
        keep_intermediate : bool
            Keep intermediate files after the computation finishes.
        '''
        self.file_filters = ['.'] # selects all files...
        self.parameters = dict()
        self.schema = load_project_schema(project)
        self.keep_intermediate = keep_intermediate
        self.job_id = job_id
        if self.job_id is not None:
            self._check_if_taken()

        self.paths = None
        self.local_path = Path(prefs['local_paths'][0])
        self.scratch_path = Path(prefs['scratch_path'])
        self.assigned_files = None
        self.dataset_key = None
        self.is_container = False
        if allow_s3 is None:
            self.allow_s3 = prefs['allow_s3_download']
        else:
            # BUGFIX: previously self.allow_s3 was never assigned when the
            # caller passed an explicit allow_s3 value.
            self.allow_s3 = allow_s3
        if 'LABDATA_CONTAINER' in os.environ.keys():
            # then it is running inside a container
            self.is_container = True
        #self.is_ec2 = False # then files should be taken from s3

    def _init_job(self): # to run in the init function
        """Claim the job in the DB, load its parameters and assigned files.

        Marks the task WORKING inside a transaction, registers a cleanup
        handler that flags the task CANCELLED if the process dies, and
        populates ``self.parameters``, ``self.assigned_files`` and
        ``self.dataset_key`` from the ComputeTask row.
        """
        if self.job_id is not None:
            with self.schema.dj.conn().transaction:
                self.jobquery = (self.schema.ComputeTask() & dict(job_id = self.job_id))
                job_status = self.jobquery.fetch(as_dict = True)
                if len(job_status):
                    if not job_status[0]['task_waiting']:
                        print(f'Checking job_status - task was not waiting: {job_status}', flush = True)
                        if 'SLURM_RESTART_COUNT' in os.environ.keys():
                            # then the job is running on slurm.. its a putative restart, try to run it..
                            self.set_job_status(job_status = 'WORKING',
                                                job_waiting = 0)
                        else:
                            print(f"Compute task [{self.job_id}] is already taken.")
                            print(job_status, flush = True)
                            return # exit.
                    else:
                        self.set_job_status(job_status = 'WORKING',
                                            job_waiting = 0)

                        def cleanup_function(job_id = self.job_id):
                            # if the process quits while WORKING, register the
                            # task as CANCELLED and put it back in the queue
                            print('Running the cleanup function.', flush = True)
                            status = (self.schema.ComputeTask() & dict(job_id = job_id)).fetch(as_dict = True)[0]
                            if status['task_status'] in ['WORKING']:
                                self.schema.ComputeTask.update1(dict(job_id = job_id,
                                                                     task_status = 'CANCELLED',
                                                                     task_waiting = 1,
                                                                     task_endtime = datetime.now()))

                        self.cleanup_function = cleanup_function
                        self.register_safe_exit()
                        par = json.loads(job_status[0]['task_parameters'])
                        for k in par.keys():
                            self.parameters[k] = par[k]
                        self.assigned_files = pd.DataFrame((self.schema.ComputeTask.AssignedFiles() & dict(job_id = self.job_id)).fetch())
                        self.dataset_key = dict(subject_name = job_status[0]['subject_name'],
                                                session_name = job_status[0]['session_name'],
                                                dataset_name = job_status[0]['dataset_name'])
                        if '--multisession' in job_status[0]['task_cmd']:
                            # combine multiple datasets into a list of keys
                            self.dataset_key = (self.schema.Dataset() &
                                                (self.schema.Dataset.DataFiles & self.assigned_files)).proj().fetch(as_dict = True)
                        # re-fetch so flags reflect the updated row
                        job_status = self.jobquery.fetch(as_dict = True)
                        if '--keep-files' in job_status[0]['task_cmd']:
                            self.keep_intermediate = True
                else:
                    # that should just be a problem to fix
                    raise ValueError(f'job_id {self.job_id} does not exist.')

    def register_safe_exit(self):
        """Register ``self.cleanup_function`` to run on interpreter exit."""
        import safe_exit
        safe_exit.register(self.cleanup_function)

    def unregister_safe_exit(self):
        """Remove the previously registered exit handler."""
        import safe_exit
        safe_exit.unregister(self.cleanup_function)

    def get_files(self, dset, allowed_extensions = None):
        '''
        Gets the paths and downloads from S3 if needed.

        Parameters
        ----------
        dset : DataFrame or list of dict
            Must expose ``file_path`` and ``storage`` columns.
        allowed_extensions : list, optional
            Extensions accepted by ``find_local_filepath``.

        Returns
        -------
        localfiles : array of local paths that were found (or downloaded).
        '''
        if allowed_extensions is None:
            # avoid a shared mutable default argument
            allowed_extensions = []
        if type(dset) is list:
            # then it is a list of dicts, convert to DataFrame
            dset = pd.DataFrame(dset)
        files = dset.file_path.values
        print('---')
        print(files)
        print('---')
        storage = dset.storage.values
        localpath = Path(prefs['local_paths'][0])
        self.files_existed = True
        localfiles = [find_local_filepath(f,
                                          allowed_extensions = allowed_extensions) for f in files]
        localfiles = np.unique(list(filter(lambda x: x is not None, localfiles)))
        if not len(localfiles) >= len(self.dataset_key): # tries to have at least one file per dataset (check this assumption, it is here because of the "allowed extensions")
            # then you can try downloading the files
            if self.allow_s3: # get the files from s3
                #TODO: then it should download using "File"
                from ..s3 import copy_from_s3
                for s in np.unique(storage):
                    # so it can work with multiple storages
                    # BUGFIX: download inside the loop; previously only the
                    # last storage's files were copied.
                    srcfiles = [f for f in files[storage == s]]
                    dstfiles = [localpath/f for f in srcfiles]
                    print(f'Downloading {len(srcfiles)} files from S3 [{s}].')
                    copy_from_s3(srcfiles, dstfiles, storage_name = s)
                localfiles = np.unique([find_local_filepath(
                    f,
                    allowed_extensions = allowed_extensions) for f in files])
                if len(localfiles):
                    self.files_existed = False # delete the files in the end if they were not local.
            else:
                print(files, localpath)
                raise ValueError('Files not found locally, set allow_s3 in the preferences to download.')
        return localfiles

    def place_tasks_in_queue(self, datasets, task_cmd = None, force_submit = False, multisession = False):
        # override this to submit special compute tasks (e.g. SpksCompute)
        return self._place_tasks_in_queue(datasets, task_cmd = task_cmd,
                                   force_submit = force_submit,
                                   multisession = multisession)

    def _place_tasks_in_queue(self, datasets, task_cmd = None, force_submit = False, multisession = False, parameters = None):
        ''' This will put the tasks in the queue for each dataset.
        If the task and parameters are the same it will return the job_id instead.

        Parameters
        ----------
        datasets : list of dataset keys (or None for tasks with no files)
        task_cmd : str, optional command line stored with the task
        force_submit : bool, delete a matching previous job and resubmit
        multisession : bool, submit all datasets as one combined task
        parameters : dict, optional override of ``self.parameters``

        Returns
        -------
        job_ids : list of created job identifiers
        '''
        if parameters is None:
            parameters = self.parameters # so we can pass multiple parameters (e.g. in the multiprobe case)
        job_ids = []
        if datasets is None:
            datasets = [None]
        if multisession:
            print('Combining data from multiple sessions/datasets.')
            datasets = [datasets]
        for dataset in datasets:
            if dataset is not None: # then there are associated files.
                files = pd.DataFrame((self.schema.Dataset.DataFiles() & dataset).fetch())
                idx = []
                # keep only files matching any of the file filters
                for f in self.file_filters:
                    idx += list(filter(lambda x: x is not None, [i if f in s else None for i, s in enumerate(
                        files.file_path.values)]))
                if len(idx) == 0:
                    raise ValueError(f'Could not find valid Dataset.DataFiles for {dataset}')
                files = files.iloc[idx]
                if type(dataset) is dict:
                    key = dict(dataset, task_name = self.name)
                else:
                    key = dict(dataset[0], task_name = self.name)
                exists = self.schema.ComputeTask() & key
                if len(exists):
                    d = pd.DataFrame(exists.fetch())
                    # check for a previous job with identical parameters
                    idx = np.where(np.array(d.task_parameters.values) == json.dumps(parameters))[0]
                    if len(idx):
                        job_id = d.iloc[idx].job_id.values[0]
                        print(f'There is a task to analyse dataset {key} with the same parameters. [{job_id}]')
                        if force_submit:
                            print('Deleting the previous job because force_submit is set.')
                            with self.schema.dj.conn().transaction:
                                self.schema.dj.config['safemode'] = False
                                # delete part table because the reference priviledge is not sufficient
                                (self.schema.ComputeTask.AssignedFiles & f'job_id = {job_id}').delete(force = True)
                                (self.schema.ComputeTask & f'job_id = {job_id}').delete() # deleting a previous job because of force_submit
                                self.schema.dj.config['safemode'] = True
                        else:
                            continue
            else:
                key = dict(task_name = self.name)
                files = None

            with self.schema.dj.conn().transaction:
                # next job_id = max existing + 1 (1 when the table is empty)
                job_id = self.schema.ComputeTask().fetch('job_id')
                if len(job_id):
                    job_id = np.max(job_id) + 1
                else:
                    job_id = 1
                if task_cmd is not None:
                    if len(task_cmd) > 1999:
                        # truncate so it fits the varchar column
                        task_cmd = task_cmd[:1999]
                self.schema.ComputeTask().insert1(dict(key,
                                                       job_id = job_id,
                                                       task_waiting = 1,
                                                       task_status = "WAITING",
                                                       task_target = None,
                                                       task_host = None,
                                                       task_cmd = task_cmd,
                                                       task_parameters = json.dumps(parameters),
                                                       task_log = None))
                if files is not None:
                    self.schema.ComputeTask.AssignedFiles().insert([dict(job_id = job_id,
                                                                         storage = f.storage,
                                                                         file_path = f.file_path)
                                                                for i, f in files.iterrows()])
                job_ids.append(job_id)
        return job_ids

    def find_datasets(self, subject_name = None, session_name = None, dataset_name = None):
        '''
        Find datasets to analyze, this function will search in the proper tables if datasets are available.
        Has to be implemented per Compute class since it varies.
        '''
        raise NotImplementedError('The find_datasets method has to be implemented.')

    def secondary_parse(self, secondary_arguments, parameter_number = None):
        """Parse task-specific command-line arguments (delegates to subclass)."""
        self._secondary_parse(secondary_arguments, parameter_number)

    def _secondary_parse(self, secondary_arguments, parameter_number = None):
        # BUGFIX: accept parameter_number so the base implementation matches
        # the call made by secondary_parse (previously raised TypeError).
        return

    def _check_if_taken(self):
        """Raise if the job does not exist or is already claimed by a worker."""
        if self.job_id is not None:
            self.jobquery = (self.schema.ComputeTask() & dict(job_id = self.job_id))
            job_status = self.jobquery.fetch(as_dict = True)
            if len(job_status):
                if job_status[0]['task_waiting']:
                    return
                else:
                    print(job_status, flush = True)
                    raise ValueError(f'job_id {self.job_id} is already taken.')
            else:
                raise ValueError(f'job_id {self.job_id} does not exist.')
        else:
            raise ValueError(f'Compute: job_id not specified.')

    def compute(self):
        '''This calls the compute function. 
If "use_s3" is true it will download the files from s3 when needed.'''
        try:
            if self.job_id is not None:
                dd = dict(job_id = self.job_id,
                          task_starttime = datetime.now())
                self.schema.ComputeTask().update1(dd)
            self._compute() # can use the src_paths
        except Exception as err:
            # log the error and mark the task FAILED
            print(f'There was an error processing job {self.job_id}.')
            err = str(traceback.format_exc()) + "ERROR" + str(err)
            print(err)

            if len(err) > 1999: # then get only the last part of the error.
                err = err[-1900:]
            if type(err) is str:
                # avoid encoding errors
                err = err.encode('utf-8')
            self.set_job_status(job_status = 'FAILED',
                                job_log = f'{err}')
            return

        # get the job from the DB if the status is not failed, mark completed (remember to clean the log)
        self.jobquery = (self.schema.ComputeTask() & dict(job_id = self.job_id))
        job_status = self.jobquery.fetch(as_dict = True)

        if not job_status[0]['task_status'] in ['FAILED']:
            self._post_compute() # so the rules can insert tables and all.
            # delete the job if has --delete-on-complete
            if '--delete-on-complete' in job_status[0]['task_cmd']:
                self.jobquery.delete(safemode = False)
                self.job_id = None
                self.unregister_safe_exit()
            else:
                self.set_job_status(job_status = 'COMPLETED')

        if self.job_id is not None: # set the complete time
            dd = dict(job_id = self.job_id,
                      task_endtime = datetime.now())
            self.schema.ComputeTask().update1(dd)

    def set_job_status(self, job_status = None, job_log = None, job_waiting = 0):
        """Update status/log/waiting flags of the current job in ComputeTask."""
        from ..schema import ComputeTask
        if self.job_id is not None:
            dd = dict(job_id = self.job_id,
                      task_waiting = job_waiting,
                      task_host = prefs['hostname']) # so we know where it failed.
            if job_status is not None:
                dd['task_status'] = job_status
            if job_log is not None:
                dd['task_log'] = job_log
                if type(dd['task_log']) is str:
                    #prevent error due to unsupported characters
                    dd['task_log'] = dd['task_log'].encode('utf-8')
            self.schema.ComputeTask.update1(dd)
            if job_status is not None:
                if not 'WORK' in job_status: # display the message
                    print(f'Check job_id {self.job_id} : {job_status}')

    def _post_compute(self):
        '''
        Inserts the data to the database. Override in subclasses.
        '''
        return

    def _compute(self):
        '''
        Runs the compute job on a scratch folder. Override in subclasses.
        '''
        return

__init__(job_id, project=None, allow_s3=None, keep_intermediate=False)

Executes a computation on a dataset, which can be remote or local. Uses a singularity/apptainer image if possible.

Source code in labdata/compute/utils.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def __init__(self, job_id, project = None, allow_s3 = None, keep_intermediate = False):
    '''
    Executes a computation on a dataset, that can be remote or local.
    Uses a singularity/apptainer image if possible.

    allow_s3 defaults to the 'allow_s3_download' preference when None.
    '''
    self.file_filters = ['.'] # selects all files...
    self.parameters = dict()
    self.schema = load_project_schema(project)
    self.keep_intermediate = keep_intermediate
    self.job_id = job_id
    if self.job_id is not None:
        self._check_if_taken()

    self.paths = None
    self.local_path = Path(prefs['local_paths'][0])
    self.scratch_path = Path(prefs['scratch_path'])
    self.assigned_files = None
    self.dataset_key = None
    self.is_container = False
    if allow_s3 is None:
        self.allow_s3 = prefs['allow_s3_download']
    else:
        # BUGFIX: previously self.allow_s3 was never assigned when the
        # caller passed an explicit allow_s3 value.
        self.allow_s3 = allow_s3
    if 'LABDATA_CONTAINER' in os.environ.keys():
        # then it is running inside a container
        self.is_container = True

compute()

This calls the compute function. If "use_s3" is true, it will download the files from S3 when needed.

Source code in labdata/compute/utils.py
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
    def compute(self):
        '''This calls the compute function. 
If "use_s3" is true it will download the files from s3 when needed.'''
        # Wraps the subclass _compute() with status bookkeeping:
        # start time -> run -> FAILED on exception, otherwise
        # _post_compute() and COMPLETED (or job deletion), then end time.
        try:
            if not self.job_id is None:
                # record when the task actually started
                dd = dict(job_id = self.job_id,
                          task_starttime = datetime.now())
                self.schema.ComputeTask().update1(dd)
            self._compute() # can use the src_paths
        except Exception as err:
            # log the error
            print(f'There was an error processing job {self.job_id}.')
            err =  str(traceback.format_exc()) + "ERROR" +str(err)
            print(err)

            if len(err) > 1999: # then get only the last part of the error.
                err = err[-1900:]
            if type(err) is str:
                # avoid encoding errors
                err = err.encode('utf-8')
            self.set_job_status(job_status = 'FAILED',
                                job_log = f'{err}')
            return

        # get the job from the DB if the status is not failed, mark completed (remember to clean the log)
        self.jobquery = (self.schema.ComputeTask() & dict(job_id = self.job_id))
        job_status = self.jobquery.fetch(as_dict = True)

        if not job_status[0]['task_status'] in ['FAILED']:
            self._post_compute() # so the rules can insert tables and all.
            # delete the job if has --delete-on-complete
            if '--delete-on-complete' in job_status[0]['task_cmd']:
                self.jobquery.delete(safemode = False)
                self.job_id = None
                self.unregister_safe_exit()
            else:
                self.set_job_status(job_status = 'COMPLETED')

        if not self.job_id is None: # set the complete time
            dd = dict(job_id = self.job_id,
                      task_endtime = datetime.now())
            self.schema.ComputeTask().update1(dd)

find_datasets(subject_name=None, session_name=None, dataset_name=None)

Find datasets to analyze, this function will search in the proper tables if datasets are available. Has to be implemented per Compute class since it varies.

Source code in labdata/compute/utils.py
469
470
471
472
473
474
def find_datasets(self, subject_name=None, session_name=None, dataset_name=None):
    """Locate datasets that are eligible for analysis.

    Each Compute subclass searches different tables, so there is no
    generic implementation; subclasses must override this method.
    """
    raise NotImplementedError('The find_datasets method has to be implemented.')

get_files(dset, allowed_extensions=[])

Gets the paths and downloads from S3 if needed.

Source code in labdata/compute/utils.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
def get_files(self, dset, allowed_extensions = None):
    '''
    Gets the paths and downloads from S3 if needed.

    dset must expose ``file_path`` and ``storage`` columns (DataFrame or
    list of dicts). Returns the unique local paths that were found.
    '''
    if allowed_extensions is None:
        # avoid a shared mutable default argument
        allowed_extensions = []
    if type(dset) is list:
        # then it is a list of dicts, convert to DataFrame
        dset = pd.DataFrame(dset)
    files = dset.file_path.values
    print('---')
    print(files)
    print('---')
    storage = dset.storage.values
    localpath = Path(prefs['local_paths'][0])
    self.files_existed = True
    localfiles = [find_local_filepath(f,
                                      allowed_extensions = allowed_extensions) for f in files]
    localfiles = np.unique(list(filter(lambda x: x is not None, localfiles)))
    if not len(localfiles) >= len(self.dataset_key): # tries to have at least one file per dataset (check this assumption, it is here because of the "allowed extensions")
        # then you can try downloading the files
        if self.allow_s3: # get the files from s3
            #TODO: then it should download using "File"
            from ..s3 import copy_from_s3
            for s in np.unique(storage):
                # so it can work with multiple storages
                # BUGFIX: download inside the loop; previously only the
                # last storage's files were copied.
                srcfiles = [f for f in files[storage == s]]
                dstfiles = [localpath/f for f in srcfiles]
                print(f'Downloading {len(srcfiles)} files from S3 [{s}].')
                copy_from_s3(srcfiles, dstfiles, storage_name = s)
            localfiles = np.unique([find_local_filepath(
                f,
                allowed_extensions = allowed_extensions) for f in files])
            if len(localfiles):
                self.files_existed = False # delete the files in the end if they were not local.
        else:
            print(files, localpath)
            raise ValueError('Files not found locally, set allow_s3 in the preferences to download.')
    return localfiles

PopulateCompute

Bases: BaseCompute

Source code in labdata/compute/utils.py
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
class PopulateCompute(BaseCompute):
    """Compute task that runs ``populate`` on an arbitrary datajoint table."""
    container = 'labdata-base'
    cuda = False
    name = 'populate'
    url = 'http://github.com/jcouto/labdata'
    def __init__(self, job_id, project = None, allow_s3 = None, **kwargs):
        super(PopulateCompute, self).__init__(job_id, project = project, allow_s3 = allow_s3)
        self.file_filters = None  # populate tasks have no assigned files
        # default parameters
        self.parameters = dict(imports = 'labdata.schema',
                               table = 'UnitMetrics',
                               processes = 10)

        self._init_job() # gets the parameters

    def _secondary_parse(self, arguments, parameter_number):
        '''
        Handles parsing the command line interface.

        Fills ``self.parameters`` from the arguments and, when the target
        table declares ``default_container``, switches the container.
        '''
        import argparse
        parser = argparse.ArgumentParser(
            description = 'Populate arbitrary tables',
            usage = 'populate -- -t <TABLE> -i <IMPORTS>')

        parser.add_argument('table', action='store', type = str,
                            help = 'Table to populate')
        parser.add_argument('-s','--stop-on-errors', action='store_true', default= False,
                            help = 'Stop on errors (negates suppress_errors)')
        parser.add_argument('-r','--restrictions', action='store', default= '',
                            help = 'Restrictions to the populate table (dict(X = "x")) or completed_today (to run for sessions that were completed less than 24h ago)')
        parser.add_argument('-i','--imports', action='store', default= 'labdata.schema', type = str,
                            help = 'import modules to load the table')
        parser.add_argument('-p','--processes',
                            action='store', default=1, type = int,
                            help = "Number of parallel processes to use for populate.")

        args = parser.parse_args(arguments[1:])
        self.parameters = dict(table = args.table,
                               imports = args.imports,
                               processes = args.processes,
                               suppress_errors = not args.stop_on_errors,
                               restrictions = args.restrictions)
        # try the import and check if the default container exists.
        # NOTE: exec/eval on user-supplied table/import names; acceptable for
        # a trusted CLI but do not expose these parameters to untrusted input.
        if not self.parameters["imports"] in ['none','']:
            to_import = self.parameters["table"]
            if '.' in to_import: # to import plugins
                to_import = to_import.split('.')[0]
            exec(f'from {self.parameters["imports"]} import {to_import}')
        table = eval(f'{self.parameters["table"]}')
        if hasattr(table,'default_container'):
            self.container = table.default_container
        print(self.parameters)

    def find_datasets(self):
        # populate does not operate on specific datasets
        return

    def _compute(self):
        """Import the target table and call its ``populate`` method."""
        # import the module that defines the table (if requested)
        if not self.parameters["imports"] in ['none','']:
            to_import = self.parameters["table"]
            if '.' in to_import: # to import plugins
                to_import = to_import.split('.')[0]
            exec(f'from {self.parameters["imports"]} import {to_import}')
        processes = 1
        # check nprocesses
        if 'processes' in self.parameters.keys():
            processes = int(self.parameters['processes'])
        # submit populate
        # NOTE: exec/eval on stored task parameters; these come from the
        # ComputeTask table which is assumed trusted.
        suppress_errors = self.parameters['suppress_errors']
        if self.parameters['restrictions'] == '':
            exec(f'{self.parameters["table"]}.populate(suppress_errors={suppress_errors}, processes = {processes}, display_progress = True)')
        else:
            if self.parameters['restrictions'] == 'completed_today': # this will look for uploads that happened less than 24h ago
                restrictions = (self.schema.Session() & (self.schema.UploadJob() & 'job_status = "COMPLETED"') & 'session_datetime > DATE_SUB(CURDATE(), INTERVAL 24 HOUR)').proj().fetch(as_dict = True)
            else:
                restrictions = eval(self.parameters['restrictions'])
            exec(f'{self.parameters["table"]}.populate({restrictions}, suppress_errors={suppress_errors}, processes = {processes}, display_progress = True)')

SpksCompute

Bases: BaseCompute

Source code in labdata/compute/ephys.py
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
class SpksCompute(BaseCompute):
    """Spike sorting compute task run through the SPKS package."""
    container = 'labdata-spks'  # container image with the sorting toolchain
    cuda = True  # spike sorting requires a GPU
    ec2 = dict(small = dict(instance_type = 'g4dn.2xlarge'),   # 8 cpus, 32 GB mem, 200 GB nvme, 1 gpu
               large = dict(instance_type = 'g6.4xlarge',
                            availability_zone = 'us-west-2b')) # 16 cpus, 64 GB mem, 600 GB nvme, 1 gpu
    name = 'spks'  # task identifier in the ComputeTask table
    url = 'http://github.com/spkware/spks'
    def __init__(self,job_id, project = None, allow_s3 = None,  **kwargs):
        '''
        Set up a spike-sorting task.

        Workflow handled by this task:
          1) find the files
          2) copy just the file you need to scratch
          3) run spike sorting on that file/folder
          4) delete the raw files
          5) repeat until all probes are processed.
        '''
        super(SpksCompute,self).__init__(job_id, project = project, allow_s3 = allow_s3)
        # only act on action-potential band files
        self.file_filters = ['.ap.']
        # default parameters
        self.parameters = dict(algorithm_name = 'spks_kilosort4.0',
                               motion_correction = 1,
                               low_pass = 300.,
                               high_pass = 13000.)
        # keys written to the SpikeSortingParams table
        self.parameter_keys = ['motion_correction','low_pass','high_pass','thresholds','remove_cross_duplicates',
                               'waveforms_from_input','dredge']
        # the parameters that go on the SpikeSortingParams
        self.use_hdf5 = True  # flag to use h5py or zarr format for the waveforms.
        self.parameter_set_num = None # identifier in SpikeSortingParams
        self._init_job()
        if type(self.dataset_key) is dict:
            self.dataset_key = [self.dataset_key] # make it a list
        if not self.job_id is None:
            self.add_parameter_key()

    def _get_parameter_number(self):
        parameter_set_num = None
        # check if in spike sorting
        parameters = pd.DataFrame(self.schema.SpikeSortingParams().fetch())
        filtered_par = {k:self.parameters[k] for k in self.parameter_keys if k in self.parameters.keys()}  
        for i,r in parameters.iterrows():
            # go through every algo
            if filtered_par == json.loads(r.parameters_dict) and self.parameters['algorithm_name'] == r['algorithm_name']:
                parameter_set_num = r.parameter_set_num
        if parameter_set_num is None:
            if len(parameters) == 0:
                parameter_set_num = 1
            else:
                parameter_set_num = np.max(parameters.parameter_set_num.values)+1
            print(f'  --> Using parameter set num {parameter_set_num}')
        return parameter_set_num,parameters

    def add_parameter_key(self):
        parameter_set_num, parameters = self._get_parameter_number()
        if not parameter_set_num in parameters.parameter_set_num.values:
            filtered_par = {k:self.parameters[k] for k in self.parameter_keys if k in self.parameters.keys()}
            self.schema.SpikeSortingParams().insert1(dict(parameter_set_num = parameter_set_num,
                                               algorithm_name = self.parameters['algorithm_name'],
                                               parameters_dict = json.dumps(filtered_par),
                                               code_link = self.url),
                                          skip_duplicates=True)
        self.parameter_set_num = parameter_set_num
        recordings = self.schema.EphysRecording.ProbeSetting() & self.dataset_key
        sortings = self.schema.SpikeSorting() & [dict(d, parameter_set_num = self.parameter_set_num) for d in self.dataset_key]
        if len(recordings) == len(sortings):
            self.set_job_status(
                job_status = 'FAILED',
                job_waiting = 0,
                job_log = f'{self.dataset_key[0]} was already sorted with parameters {self.parameter_set_num}.')    
            raise(ValueError(f'{self.dataset_key[0]} was already sorted with parameters {self.parameter_set_num}.'))

    def _secondary_parse(self,arguments,parameter_number = None):
        '''
        Handles parsing the command line interface
        '''
        if not parameter_number is None:
            self.parameters = json.loads((self.schema.SpikeSortingParams() & f'parameter_set_num = {parameter_number}').fetch('parameters_dict'))
            self.parameters['algorithm_name'] = ((self.schema.SpikeSortingParams() & f'parameter_set_num = {parameter_number}')).fetch1('algorith_name')
            if not len(self.parameters):
                raise(f'Could not find parameter {parameter_number} in SpikeSortingParams.')
            self.parameters = self.parameters[0]
        else:
            import argparse
            parser = argparse.ArgumentParser(
                description = 'Spike sorting using kilosort through the spks package.',
                usage = 'spks -a <SUBJECT> -s <SESSION> -- <PARAMETERS>')

            parser.add_argument('-p','--probe',
                                action='store', default=None, type = int,
                                help = "THIS DOES NOTHING NOW. WILL BE FOR OPENING PHY")
            parser.add_argument('-m','--method',action='store',default = 'ks4.0',type = str,
                                help = 'Method for spike sorting [Kilosort] ks2.5, ks3.0, ks4.0 or [MountainSort] ms5 (default ks4.0)')
            parser.add_argument('-l','--low-pass',
                                action='store', default=self.parameters['low_pass'], type = float,
                                help = "Lowpass filter (default 300.Hz)")
            parser.add_argument('-i','--high-pass',
                                action='store', default=self.parameters['high_pass'], type = float,
                                help = "Highpass filter (default 13000.Hz)")
            parser.add_argument('-t','--thresholds',
                                action='store', default=None, type = float, nargs = 2,
                                help = "Thresholds for spike detection default depends on method.")
            parser.add_argument('-n','--motion-correction',
                                action='store', default = 1, type = int,
                                help = "Motion correction (0 to disable, 1 for rigid, 2+ for blockwise)")
            parser.add_argument('-c','--remove_cross-unit-duplicates',
                                action='store_true', default = False,
                                help = "Skip removing duplicates across units.")
            parser.add_argument('--waveforms-from-sorter',
                                action='store_true', default = False,
                                help = "Extract the waveforms from the file processed by the sorter.")
            parser.add_argument('--interval',
                                action='store', default = None, nargs=2, type = float,
                                help = "Interval to sort in seconds: start end")
            parser.add_argument('--dredge',
                                action='store_true', default = False, 
                                help = "Motion correction using dredge")

            args = parser.parse_args(arguments[1:])
            if 'ks2.5' in  args.method: # defaults for ks2.5
                self.parameters = dict(algorithm_name = 'spks_kilosort2.5',
                                    motion_correction = args.motion_correction>0,
                                    low_pass = args.low_pass,
                                    high_pass = args.high_pass,
                                    thresholds = [9.,3.],
                                    remove_cross_duplicates = args.remove_cross_unit_duplicates)
            elif 'ks3.0' in  args.method: # defaults for ks3.0
                self.parameters = dict(algorithm_name = 'spks_kilosort3.0',
                                    motion_correction = args.no_motion_correction>0,
                                    low_pass = args.low_pass,
                                    high_pass = args.high_pass,
                                    thresholds = [9.,9.],
                                    remove_cross_duplicates = args.remove_cross_unit_duplicates)
            elif 'ks4.0' in  args.method: # defaults for ks4.0
                self.parameters = dict(algorithm_name = 'spks_kilosort4.0',
                                    motion_correction = args.motion_correction,
                                    low_pass = args.low_pass,
                                    high_pass = args.high_pass,
                                    thresholds = [9.,8.],
                                    remove_cross_duplicates = args.remove_cross_unit_duplicates)
            else:
                raise(NotImplemented(f'{args.method} not implemented.'))
            if args.waveforms_from_sorter:
                # default is to extract from the original input file
                self.parameters['waveforms_from_input'] = False 
            if args.dredge:
                # default is to extract from the original input file
                self.parameters['dredge'] = True 
            print(args)
            if not args.thresholds is None:
                self.parameters['thresholds'] = args.thresholds
            if self.parameters['motion_correction'] < 2:
                self.parameters['motion_correction'] = self.parameters['motion_correction']
        self.probe = args.probe
        if not self.probe is None:
            self.parameters['probe'] = int(self.probe) # submit a single probe

    def find_datasets(self, subject_name = None, session_name = None):
        '''
        Searches for subjects and sessions in EphysRecording
        '''
        if subject_name is None and session_name is None:
            print("\n\nPlease specify a 'subject_name' and a 'session_name' to perform spike-sorting.\n\n")

        parameter_set_num, parameters = self._get_parameter_number()
        keys = []
        if not subject_name is None:
            if len(subject_name) > 1:
                raise ValueError(f'Please submit one subject at a time {subject_name}.')
            if not subject_name[0] == '':
                subject_name = subject_name[0]
        if not session_name is None:
            for s in session_name:
                if not s == '':
                    keys.append(dict(subject_name = subject_name,
                                     session_name = s))
        else:
            # find all sessions that can be spike sorted
            sessions = np.unique(((
                (self.schema.EphysRecording() & f'subject_name = "{subject_name}"') -
                (self.schema.SpikeSorting() & f'parameter_set_num = {parameter_set_num}'))).fetch('session_name'))
            for ses in sessions:
                keys.append(dict(subject_name = subject_name,
                                 session_name = ses))
        datasets = []
        for k in keys:
            datasets += (self.schema.EphysRecording()& k).proj('subject_name','session_name','dataset_name').fetch(as_dict = True)

        if not parameter_set_num is None:
            datasets = (self.schema.EphysRecording & ((self.schema.EphysRecording.ProbeSetting() & datasets) -
                        (self.schema.SpikeSorting() & f'parameter_set_num = {parameter_set_num}'))).proj(
                            'subject_name',
                            'session_name',
                            'dataset_name').fetch(as_dict = True)
        print(keys,parameter_set_num)
        return datasets

    def place_tasks_in_queue(self,datasets,task_cmd = None, force_submit = False, multisession = False):
        # this is a special place tasks so we can submit a compute task per probe (when there are multiple probes per dataset)
        if not 'probe' in self.parameters.keys():
            probes = np.unique((self.schema.EphysRecording.ProbeSetting & datasets).fetch('probe_num'))
            jobids = []
            for probe in probes:
                p = dict(self.parameters,probe = int(probe))
                jobids.extend(self._place_tasks_in_queue(datasets,
                                           task_cmd = task_cmd, 
                                           force_submit = force_submit, 
                                           multisession = multisession, 
                                           parameters = p))
            return jobids
        else:
            return self._place_tasks_in_queue(datasets,
                                              task_cmd = task_cmd, 
                                              force_submit = force_submit, 
                                              multisession = multisession, 
                                              parameters = self.parameters)
    def _compute(self):
        print(self.parameters)
        try:
            import matplotlib
            matplotlib.use('Agg') # launch with Agg background because tkinter can cause crashes in strange situations
        except:
            print('Could not set matplotlib Agg engine.')
            pass
        # this performs the actual spike sorting.
        datasets = pd.DataFrame((self.schema.EphysRecording.ProbeFile() & self.dataset_key).fetch())
        # check if a probe was selected
        if 'probe' in self.parameters.keys():
            datasets = datasets[datasets.probe_num.values == self.parameters['probe']]
        for probe_num in np.unique(datasets.probe_num):
            self.set_job_status(job_log = f'Sorting {probe_num}')
            files = datasets[datasets.probe_num.values == probe_num]
            dset = []
            for i,f in files.iterrows():
                if 'ap.cbin' in f.file_path or 'ap.ch' in f.file_path:
                    dset.append(i)
                elif 'ap.meta' in f.file_path: # requires a metadata file (spikeglx)
                    dset.append(i)
            dset = files.loc[dset]
            if not len(dset):
                print(files)
                raise(ValueError(f'Could not find ap.cbin files for probe {probe_num}'))
            localfiles = self.get_files(dset, allowed_extensions = ['.ap.bin'])
            probepath = list(filter(lambda x: str(x).endswith('bin'),localfiles))
            if 'kilosort' in self.parameters['algorithm_name']:
                from spks.sorting import run_kilosort
            dredge_motion_correction = False
            if 'dredge' in self.parameters.keys():
                dredge_motion_correction = self.parameters['dredge']
            # run kilosort using spks for preprocessing.
            if self.parameters['algorithm_name'] == 'spks_kilosort2.5':      
                results_folder = run_kilosort(version = '2.5',sessionfiles = probepath,
                                              temporary_folder = prefs['scratch_path'],
                                              do_post_processing = False,
                                              motion_correction = self.parameters['motion_correction'],
                                              thresholds = self.parameters['thresholds'],
                                              lowpass = self.parameters['low_pass'],
                                              highpass = self.parameters['high_pass'],
                                              dredge_motion_correction = dredge_motion_correction)
            elif self.parameters['algorithm_name'] == 'spks_kilosort3.0':      
                results_folder = run_kilosort(version = '3.0',
                                              sessionfiles = probepath,
                                              temporary_folder = prefs['scratch_path'],
                                              do_post_processing = False,
                                              motion_correction = self.parameters['motion_correction'],
                                              thresholds = self.parameters['thresholds'],
                                              lowpass = self.parameters['low_pass'],
                                              highpass = self.parameters['high_pass'],
                                              dredge_motion_correction = dredge_motion_correction)
            elif self.parameters['algorithm_name'] == 'spks_kilosort4.0':      
                results_folder = run_kilosort(version = '4.0',
                                              sessionfiles = probepath,
                                              temporary_folder = prefs['scratch_path'],
                                              do_post_processing = False,
                                              motion_correction = self.parameters['motion_correction'],
                                              thresholds = self.parameters['thresholds'],
                                              lowpass = self.parameters['low_pass'],
                                              highpass = self.parameters['high_pass'],
                                              dredge_motion_correction = dredge_motion_correction)
            elif self.parameters['algorithm_name'] == 'spks_mountainsort5':
                raise(NotImplemented(f"[{self.name} job] - Algorithm {self.parameters['algorithm_name']} not implemented."))
            else:
                raise(NotImplemented(f"[{self.name} job] - Algorithm {self.parameters['algorithm_name']} not implemented."))
            self.set_job_status(job_log = f'Probe {probe_num} sorted, running post-processing.')
            try: # attempt to close all figures before using joblib 
                import pylab as plt
                plt.close('all')
            except:
                pass
            self.postprocess_and_insert(results_folder,
                                        probe_num = probe_num,
                                        remove_duplicates = True,
                                        n_pre_samples = 45)
            self.unregister_safe_exit() # in case these get triggered by shutdown
            try:
                from joblib.externals.loky import get_reusable_executor
                get_reusable_executor().shutdown(wait=True)

            except:
                print(f'[{self.name} job] Tried to clear joblib Loky executers and failed.')
            self.register_safe_exit() # put it back..

            if not self.keep_intermediate:
                # delete results_folder
                print(f'[{self.name} job] Removing the results folder {results_folder}.')
                import shutil
                shutil.rmtree(results_folder)
                # delete local files if they did not exist
                if not self.files_existed:
                    for f in localfiles:
                        os.unlink(f)
            else:
                # send the files to AWS # this will skip the raw data binary files
                dataset = dict(**self.dataset_key[0])
                dataset['dataset_name'] = f'spike_sorting/imec{probe_num}/{self.parameter_set_num}/{self.parameters["algorithm_name"]}'
                src = [a for a in Path(results_folder).rglob('*') if not (str(a).endswith('.bin') or 
                                                                          str(a).endswith('.dat') or
                                                                          str(a).endswith('.hdf5'))] # get the files minus the .bin/.dat
                new_src = [Path(prefs['local_paths'][0])/a for a in self.schema.AnalysisFile().generate_filepaths(src,dataset)]
                from tqdm.auto import tqdm
                import shutil
                def move(s,d):
                    Path(d).parent.mkdir(exist_ok = True, parents = True) # create the directory if not there, we do this for every file because they might be in diff folders.
                    shutil.move(s,d)
                res = Parallel(n_jobs=DEFAULT_N_JOBS)(delayed(move)(s,n) for s,n in tqdm(zip(src,new_src),total = len(src),desc = "Copying to intermediate files to local_path"))
                # copy to a local path
                #print(f'[{self.name} job] Moving the temporary files from {results_folder} to {kept_results_folder}')
                filekeys = self.schema.AnalysisFile().upload_files(new_src,dataset)

                base_key = dict((self.schema.EphysRecording & self.dataset_key[0]).proj().fetch1(),
                                probe_num = probe_num,
                                parameter_set_num = self.parameter_set_num)
                self.schema.SpikeSorting.IntermediateFiles.insert([dict(base_key,**k) for k in filekeys])
                shutil.rmtree(results_folder)
                print(f'[{self.name} job] Removing the results folder {results_folder}. Intermediate results kept and uploaded to AnalysisFile.')


    def prepare_results(self,results_folder,
                        probe_num,
                        remove_duplicates,
                        n_pre_samples):

        from spks import Clusters
        if remove_duplicates:
            clu = Clusters(results_folder, get_waveforms = False, get_metrics = False)
            clu.remove_duplicate_spikes(
                overwrite_phy = True,
                remove_cross_duplicates = self.parameters['remove_cross_duplicates']) 
            del clu
        clu = Clusters(results_folder, get_waveforms = False, get_metrics = False)
        clu.compute_template_amplitudes_and_depths()
        base_keys = []
        ssdict = [] # sortings
        udict = [] # unit
        featurestosave = [] # features
        events = [] # events for sync

        for ifile,(dataset_key,(o,e)) in enumerate(zip(self.dataset_key,clu.metadata['file_offsets'])):
            base_keys.append(dict((self.schema.EphysRecording & dataset_key).proj().fetch1(),
                                  probe_num = probe_num,
                                  parameter_set_num = self.parameter_set_num))
            # this handles multi-session results
            idx = np.where((clu.spike_times>=o) & (clu.spike_times<e))[0]
            spikes = clu.spike_times[idx].flatten().astype(np.uint64) - o # subtract start of session
            clusters = clu.spike_clusters[idx]
            amplitudes = clu.spike_amplitudes[idx].flatten().astype(np.float32)
            positions = clu.spike_positions[idx,:].astype(np.float32)

            ssdict.append(dict(base_keys[-1],
                               n_pre_samples = n_pre_samples,
                               n_sorted_units = len(np.unique(clusters)),
                               n_detected_spikes = len(spikes),
                               sorting_datetime = datetime.fromtimestamp(
                                   Path(results_folder).stat().st_ctime),
                               sorting_channel_indices = clu.channel_map.flatten(),
                               sorting_channel_coords = clu.channel_positions))
            udict.append([]) # list of lists
            for iclu in np.unique(clusters):
                subidx = np.where(clusters == iclu)[0]
                udict[-1].append(dict(
                    base_keys[-1],unit_id = iclu,
                    spike_positions = positions[subidx],
                    spike_times = spikes[subidx],
                    spike_amplitudes = amplitudes[subidx]))

            featurestosave.append(dict(template_features = clu.spike_pc_features[idx].astype(np.float32),
                                       spike_templates = clu.spike_templates[idx],
                                       cluster_indices = clusters,
                                       whitening_matrix = clu.whitening_matrix,
                                       templates = clu.templates, # maybe should save only existing templates?
                                       template_feature_ind = clu.template_pc_features_ind))
            k = f'file{ifile}_sync_onsets'
            stream_name = f'imec{probe_num}' # to save the events and files
            events.append([])
            if k in clu.metadata.keys():
                for ik in clu.metadata[k].keys():
                    ev = clu.metadata[k][ik].astype(np.uint64)
                    events[-1].append(dict((self.schema.Dataset & base_keys[-1]).proj().fetch1(), 
                                       stream_name = stream_name,
                                       event_name = str(k),
                                       event_timestamps = ev))
        return clu,base_keys,ssdict, udict, featurestosave, events

    def postprocess_and_insert(self,
                               results_folder,
                               probe_num,
                               remove_duplicates = True,
                               n_pre_samples = 45):
        '''Does the preprocessing for a spike sorting and inserts'''
        # get the results in a dictionary and remove duplicates
        clu,base_keys,ssdicts, udicts, featurestosave, events = self.prepare_results(results_folder,
                                                                                     probe_num,
                                                                                     remove_duplicates,
                                                                                     n_pre_samples)
        # save the features to a file, will take around 2 min for a 1h recording.
        if not featurestosave[0]['template_features'] is None:
            for idset,features in enumerate(featurestosave):
                save_dict_to_h5(Path(results_folder)/f'features{idset}.hdf5',features)
        n_jobs = DEFAULT_N_JOBS  # gets the default number of jobs from labdata
        # extract the waveforms from the binary file
        n_jobs_wave = n_jobs
        for idset,(base_key,dataset_key,udict,ssdict,offsets,event) in enumerate(zip(
            base_keys,
            self.dataset_key,
            udicts,
            ssdicts,
            clu.metadata['file_offsets'],
            events)):
            if len(udict) > 800:
                n_jobs_wave = 3 # to prevent running out of memory when collecting waveforms
            udict, binaryfile, nchannels,res = self.extract_waveforms(udict,
                                                                      clu.metadata['nchannels'],
                                                                      results_folder,
                                                                      n_pre_samples,
                                                                      n_jobs_wave,
                                                                      offsets = offsets)
            def median_waves(r,gains):
                if not r is None:
                    return np.median(r.astype(np.float32),axis = 0)*gains
                else:
                    return None
            waves_dict = []
            extras = dict(compression = 'gzip',
                          compression_opts = 1,
                          chunks = True, 
                          shuffle = True)
            from tqdm.auto import tqdm
            # print(f'Collecting waveforms and saving for dataset:{idset}.')
            # save these to zarr to be compressed faster?
            if self.use_hdf5: # zarr not implemented yet.
                import h5py as h5
                with h5.File(Path(results_folder)/f'waveforms{idset}.hdf5','w') as wavefile:
                    for u,w in tqdm(zip(udict,res), total = len(udict),
                                    desc = f'Saving waveforms to file [dataset {idset}]'):
                        m = median_waves(w, gains = clu.channel_gains)
                        if not w is None:
                            waves_dict.append(dict(base_key,
                                                    unit_id = u['unit_id'],
                                                    waveform_median = m))
                            # save to the file
                            wavefile.create_dataset(str(u['unit_id'])+'/waveforms',data = w,**extras)
                            wavefile.create_dataset(str(u['unit_id'])+'/indices',data = u['waveform_indices'],**extras)
                        else:
                            print(f"Unit {u['unit_id']} had no spikes extracted (dataset {idset})")
            stream_name = f'imec{probe_num}' # to save the events and files, also defined elsewhere (fix that on re-factor)
            src = [Path(results_folder)/f'waveforms{idset}.hdf5',Path(results_folder)/f'features{idset}.hdf5']
            dataset = dict(**dataset_key)
            dataset['dataset_name'] = f'spike_sorting/{stream_name}/{self.parameter_set_num}'

            filekeys = self.schema.AnalysisFile().upload_files(src,dataset)
            ssdict['waveforms_file'] = filekeys[0]['file_path']
            ssdict['waveforms_storage'] = filekeys[0]['storage']
            if not featurestosave[idset]['template_features'] is None:
                ssdict['features_file'] = filekeys[1]['file_path']
                ssdict['features_storage'] = filekeys[1]['storage']
            # insert the syncs
            if len(event):
                # Add stream
                print(f'Inserting {len(event)} events.')
                self.schema.DatasetEvents.insert1(dict(dataset_key,
                                                       stream_name = stream_name),
                                                  skip_duplicates = True,
                                                  allow_direct_insert = True)
                self.schema.DatasetEvents.Digital.insert(event,
                                                         skip_duplicates = True,
                                                         allow_direct_insert = True)

            # inserts
            import logging
            logging.getLogger('datajoint').setLevel(logging.WARNING)
            # these can't be done in a safe way quickly so if they fail we have delete SpikeSorting
            self.schema.SpikeSorting.insert1(ssdict,skip_duplicates = True)
            # Insert datajoint in parallel.
            parallel_insert(self.schema.schema_project,'SpikeSorting.Unit',udict, n_jobs = DEFAULT_N_JOBS,
                            skip_duplicates = True, ignore_extra_fields = True)
            parallel_insert(self.schema.schema_project,'SpikeSorting.Waveforms',waves_dict, n_jobs = DEFAULT_N_JOBS,
                            skip_duplicates = True, ignore_extra_fields = True)

            # Add a segment from a random location.
            from spks.io import map_binary
            dat = map_binary(binaryfile, nchannels = nchannels)
            nsamples = int(clu.sampling_rate*2)
            offset_samples = int(np.random.uniform(offsets[0]+nsamples, offsets[1]-nsamples-1))
            self.schema.SpikeSorting.Segment.insert1(dict(base_key,
                                            segment_num = 1,
                                            offset_samples = offset_samples - offsets[0],
                                            segment = np.array(dat[offset_samples : offset_samples + nsamples])))
            del dat
        self.set_job_status(job_log = f'Completed {base_key}')
        from labdata.schema import UnitMetrics
        # limit number of jobs because of memory constraints
        self.schema.UnitMetrics.populate(base_key, display_progress = True, processes = int(max(1,np.ceil(n_jobs/2))))

    def extract_waveforms(self,udict, nchannels, results_folder, n_pre_samples,n_jobs, offsets):
        # extract the waveforms
        from spks.io import map_binary
        # if not 'waveforms_from_sorter' in self.parameters.keys(): 
        #     self.parameters['waveforms_from_sorter'] = False
        # if not self.parameters['waveforms_from_sorter']:    
        binaryfile = list(Path(results_folder).glob("filtered_recording*.bin"))[0]
        # else: # not implemented
        #     binaryfile = list(Path(results_folder).glob("temp_wh.dat"))[0]
        #     nchannels = clu.metadata['nchannels']
        dat = map_binary(binaryfile, nchannels = nchannels) # to get the duration

        udict = select_random_waveforms(udict, 
                                        wpre = n_pre_samples, 
                                        wpost = n_pre_samples,
                                        duration = offsets[-1]-offsets[0])
        del dat
        res = get_waveforms_from_binary(binaryfile, nchannels,
                                        [u['waveform_indices']+offsets[0] for u in udict],
                                        wpre = n_pre_samples,
                                        wpost = n_pre_samples,
                                        n_jobs = n_jobs)
        return udict, binaryfile, nchannels, res

__init__(job_id, project=None, allow_s3=None, **kwargs)

1) find the files

2) copy just the file you need to scratch

3) run spike sorting on that file/folder

4) delete the raw files

5) repeat until all probes are processed.

Source code in labdata/compute/ephys.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
    def __init__(self,job_id, project = None, allow_s3 = None,  **kwargs):
        '''
        Set up a spike-sorting compute task. Overall workflow:

#1) find the files
#2) copy just the file you need to scratch
#3) run spike sorting on that file/folder
#4) delete the raw files
#5) repeat until all probes are processed.
        '''
        super(SpksCompute,self).__init__(job_id, project = project, allow_s3 = allow_s3)
        self.file_filters = ['.ap.']  # only fetch action-potential-band files
        # default parameters (replaced by the CLI parse or a stored parameter set)
        self.parameters = dict(algorithm_name = 'spks_kilosort4.0',
                               motion_correction = 1,
                               low_pass = 300.,
                               high_pass = 13000.)
        # the parameters that go on the SpikeSortingParams parameters_dict
        self.parameter_keys = ['motion_correction','low_pass','high_pass','thresholds','remove_cross_duplicates',
                               'waveforms_from_input','dredge']
        self.use_hdf5 = True  # flag to use h5py or zarr format for the waveforms.
        self.parameter_set_num = None # identifier in SpikeSortingParams
        self._init_job()
        if type(self.dataset_key) is dict:
            self.dataset_key = [self.dataset_key] # make it a list
        if not self.job_id is None:
            # register (or look up) the parameter set; fails early if already sorted
            self.add_parameter_key()
find_datasets(subject_name=None, session_name=None)

Searches for subjects and sessions in EphysRecording

Source code in labdata/compute/ephys.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def find_datasets(self, subject_name = None, session_name = None):
    '''
    Searches for subjects and sessions in EphysRecording.

    Returns a list of dataset keys (subject_name, session_name, dataset_name)
    whose probes have not been sorted with the current parameter set.
    '''
    if subject_name is None and session_name is None:
        print("\n\nPlease specify a 'subject_name' and a 'session_name' to perform spike-sorting.\n\n")

    parameter_set_num, parameters = self._get_parameter_number()
    keys = []
    if subject_name is not None:
        # only one subject may be submitted per job
        if len(subject_name) > 1:
            raise ValueError(f'Please submit one subject at a time {subject_name}.')
        if not subject_name[0] == '':
            subject_name = subject_name[0]
    if session_name is not None:
        for s in session_name:
            if not s == '':
                keys.append(dict(subject_name = subject_name,
                                 session_name = s))
    else:
        # no sessions given: find all sessions of this subject that were not
        # yet sorted with this parameter set
        sessions = np.unique(((
            (self.schema.EphysRecording() & f'subject_name = "{subject_name}"') -
            (self.schema.SpikeSorting() & f'parameter_set_num = {parameter_set_num}'))).fetch('session_name'))
        for ses in sessions:
            keys.append(dict(subject_name = subject_name,
                             session_name = ses))
    datasets = []
    for k in keys:
        datasets += (self.schema.EphysRecording() & k).proj('subject_name','session_name','dataset_name').fetch(as_dict = True)

    if parameter_set_num is not None:
        # drop datasets whose probe settings were already sorted with this parameter set
        datasets = (self.schema.EphysRecording & ((self.schema.EphysRecording.ProbeSetting() & datasets) -
                    (self.schema.SpikeSorting() & f'parameter_set_num = {parameter_set_num}'))).proj(
                        'subject_name',
                        'session_name',
                        'dataset_name').fetch(as_dict = True)
    return datasets

postprocess_and_insert(results_folder, probe_num, remove_duplicates=True, n_pre_samples=45)

Does the postprocessing for a spike sorting and inserts the results

Source code in labdata/compute/ephys.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
def postprocess_and_insert(self,
                           results_folder,
                           probe_num,
                           remove_duplicates = True,
                           n_pre_samples = 45):
    '''Does the preprocessing for a spike sorting and inserts'''
    # get the results in a dictionary and remove duplicates
    clu,base_keys,ssdicts, udicts, featurestosave, events = self.prepare_results(results_folder,
                                                                                 probe_num,
                                                                                 remove_duplicates,
                                                                                 n_pre_samples)
    # save the features to a file, will take around 2 min for a 1h recording.
    if not featurestosave[0]['template_features'] is None:
        for idset,features in enumerate(featurestosave):
            save_dict_to_h5(Path(results_folder)/f'features{idset}.hdf5',features)
    n_jobs = DEFAULT_N_JOBS  # gets the default number of jobs from labdata
    # extract the waveforms from the binary file
    n_jobs_wave = n_jobs
    for idset,(base_key,dataset_key,udict,ssdict,offsets,event) in enumerate(zip(
        base_keys,
        self.dataset_key,
        udicts,
        ssdicts,
        clu.metadata['file_offsets'],
        events)):
        if len(udict) > 800:
            n_jobs_wave = 3 # to prevent running out of memory when collecting waveforms
        udict, binaryfile, nchannels,res = self.extract_waveforms(udict,
                                                                  clu.metadata['nchannels'],
                                                                  results_folder,
                                                                  n_pre_samples,
                                                                  n_jobs_wave,
                                                                  offsets = offsets)
        def median_waves(r,gains):
            if not r is None:
                return np.median(r.astype(np.float32),axis = 0)*gains
            else:
                return None
        waves_dict = []
        extras = dict(compression = 'gzip',
                      compression_opts = 1,
                      chunks = True, 
                      shuffle = True)
        from tqdm.auto import tqdm
        # print(f'Collecting waveforms and saving for dataset:{idset}.')
        # save these to zarr to be compressed faster?
        if self.use_hdf5: # zarr not implemented yet.
            import h5py as h5
            with h5.File(Path(results_folder)/f'waveforms{idset}.hdf5','w') as wavefile:
                for u,w in tqdm(zip(udict,res), total = len(udict),
                                desc = f'Saving waveforms to file [dataset {idset}]'):
                    m = median_waves(w, gains = clu.channel_gains)
                    if not w is None:
                        waves_dict.append(dict(base_key,
                                                unit_id = u['unit_id'],
                                                waveform_median = m))
                        # save to the file
                        wavefile.create_dataset(str(u['unit_id'])+'/waveforms',data = w,**extras)
                        wavefile.create_dataset(str(u['unit_id'])+'/indices',data = u['waveform_indices'],**extras)
                    else:
                        print(f"Unit {u['unit_id']} had no spikes extracted (dataset {idset})")
        stream_name = f'imec{probe_num}' # to save the events and files, also defined elsewhere (fix that on re-factor)
        src = [Path(results_folder)/f'waveforms{idset}.hdf5',Path(results_folder)/f'features{idset}.hdf5']
        dataset = dict(**dataset_key)
        dataset['dataset_name'] = f'spike_sorting/{stream_name}/{self.parameter_set_num}'

        filekeys = self.schema.AnalysisFile().upload_files(src,dataset)
        ssdict['waveforms_file'] = filekeys[0]['file_path']
        ssdict['waveforms_storage'] = filekeys[0]['storage']
        if not featurestosave[idset]['template_features'] is None:
            ssdict['features_file'] = filekeys[1]['file_path']
            ssdict['features_storage'] = filekeys[1]['storage']
        # insert the syncs
        if len(event):
            # Add stream
            print(f'Inserting {len(event)} events.')
            self.schema.DatasetEvents.insert1(dict(dataset_key,
                                                   stream_name = stream_name),
                                              skip_duplicates = True,
                                              allow_direct_insert = True)
            self.schema.DatasetEvents.Digital.insert(event,
                                                     skip_duplicates = True,
                                                     allow_direct_insert = True)

        # inserts
        import logging
        logging.getLogger('datajoint').setLevel(logging.WARNING)
        # these can't be done in a safe way quickly so if they fail we have delete SpikeSorting
        self.schema.SpikeSorting.insert1(ssdict,skip_duplicates = True)
        # Insert datajoint in parallel.
        parallel_insert(self.schema.schema_project,'SpikeSorting.Unit',udict, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)
        parallel_insert(self.schema.schema_project,'SpikeSorting.Waveforms',waves_dict, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)

        # Add a segment from a random location.
        from spks.io import map_binary
        dat = map_binary(binaryfile, nchannels = nchannels)
        nsamples = int(clu.sampling_rate*2)
        offset_samples = int(np.random.uniform(offsets[0]+nsamples, offsets[1]-nsamples-1))
        self.schema.SpikeSorting.Segment.insert1(dict(base_key,
                                        segment_num = 1,
                                        offset_samples = offset_samples - offsets[0],
                                        segment = np.array(dat[offset_samples : offset_samples + nsamples])))
        del dat
    self.set_job_status(job_log = f'Completed {base_key}')
    from labdata.schema import UnitMetrics
    # limit number of jobs because of memory constraints
    self.schema.UnitMetrics.populate(base_key, display_progress = True, processes = int(max(1,np.ceil(n_jobs/2))))

DeeplabcutCompute

Bases: BaseCompute

Source code in labdata/compute/pose.py
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
class DeeplabcutCompute(BaseCompute):
    '''Pose estimation using DeepLabCut: train a model from a label set
    ('train' mode) or run inference with a trained model on session videos
    ('infer' mode).'''
    container = 'labdata-deeplabcut'
    cuda = True
    name = 'deeplabcut'
    url = 'http://github.com/DeepLabCut/DeepLabCut'
    def __init__(self,job_id, project = None, allow_s3 = None, **kwargs):
        '''
        Run deeplabcut on video or train a model
        '''
        super(DeeplabcutCompute,self).__init__(job_id, project = project, allow_s3 = allow_s3)
        self.file_filters = ['.avi','.mov','.mp4','.zarr'] # allowed file extensions..
        # default parameters ('iteractions' spelling kept for stored-parameter compatibility)
        self.parameters = dict(algorithm = 'deeplabcut',
                               mode = None, # select 'train' or 'infer'
                               model_num = None,
                               label_set = None,
                               video_name = None,
                               net_type = 'resnet_50',
                               batch_size = 8,
                               iteractions = 100000)
        self._init_job()
        if self.job_id is not None:
            self.add_parameter_key()

    def add_parameter_key(self):
        '''Resolve and store the model number this job will train/use.'''
        model_num, parameters = self._get_parameter_number()
        if self.parameters['mode'] == 'train':
            # the PoseEstimationModel row is inserted after training completes
            # (see _compute), here we only remember which model_num to use
            self.model_num = model_num
        # check here if it was already infered with this model.

    def _get_parameter_number(self):
        '''Return (model_num, PoseEstimationModel table as a DataFrame).

        In 'train' mode, re-use the model_num of an existing model with the
        same parameters and label set, otherwise pick the requested or the
        next available number.  In other modes just return the requested one.
        '''
        self.model_parameters = dict(algorithm = self.parameters['algorithm'],
                                     net_type = self.parameters['net_type'],
                                     batch_size = self.parameters['batch_size'],
                                     iteractions = self.parameters['iteractions'])
        parameter_set_num = None
        parameters = pd.DataFrame(self.schema.PoseEstimationModel().fetch())
        model_num = None
        if self.parameters['mode'] == 'train':
            for i,r in parameters.iterrows():
                # go through every parameter and label_set
                # NOTE: compare against the configured 'label_set'; the
                # parameters dict has no 'pose_label_set_num' key, which
                # previously raised a KeyError on a parameters match.
                if (self.model_parameters == json.loads(r.parameters_dict) and 
                    self.parameters['model_num'] is None and 
                    self.parameters['label_set'] == r['pose_label_set_num']):
                    model_num = r.model_num
            if model_num is None:
                if self.parameters['model_num'] is not None:
                    model_num = self.parameters['model_num']
                elif len(parameters) == 0:
                    model_num = 1  # first model ever trained
                else:
                    model_num = np.max(parameters.model_num.values)+1
            self.parameters['model_num'] = model_num
            return model_num,parameters
        else:
            return self.parameters['model_num'],parameters

    def _secondary_parse(self,arguments,parameters = None):
        '''
        Handles parsing the command line interface
        '''
        if parameters is not None: # can just pass the parameters
            self.parameters = parameters
        else:
            import argparse
            parser = argparse.ArgumentParser(
                description = 'Pose estimation analysis using DeepLabCut',
                usage = '''
    deeplabcut -a <SUBJECT> -s <SESSION> -- <TRAIN|INFER> <PARAMETERS>

    Example for inference using a trained model (-m 1):

        labdata2 run deeplabcut -a JC131 -s 20231025_194303 -- infer -m 1 -v side_cam            

                ''')

            parser.add_argument('mode',action='store', type = str,
                                help = '[required] Specifies what to do (train or infer)')
            parser.add_argument('-v','--video-name',
                                action='store', type = str, default = None,
                                help = "Select files to analyze (DatasetVideo.video_name)")
            parser.add_argument('-l','--label-set',
                                action='store', default=None, type = int,
                                help = "Label set to run training.")
            parser.add_argument('-m','--model-num',
                                action='store', default=None, type = int,
                                help = "Model number to run inference.")
            parser.add_argument('--net-type',
                                action='store', type = str, default = 'resnet_50',
                                help = "Network to run (has to be in the container - resnet_50; resnet_101)")
            # NOTE(review): CLI default (300000) differs from the class default
            # (100000) - confirm which is intended.
            parser.add_argument('-i','--iteractions',
                                action='store', default=300000, type = int,
                                help = "Number of iteractions for training")
            args = parser.parse_args(arguments[1:])
            self.parameters['mode'] = args.mode
            self.parameters['video_name'] = args.video_name
            self.parameters['label_set'] = args.label_set
            self.parameters['model_num'] = args.model_num
            self.parameters['net_type'] = args.net_type
            self.parameters['iteractions'] = args.iteractions
        # validate the combination of mode and parameters
        if 'train' in self.parameters['mode']:
            if self.parameters['label_set'] is None:
                raise(ValueError('Need to define a label-set to train a model.'))
        else:
            if self.parameters['model_num'] is None:
                raise(ValueError('Need to specify a model.'))
            if not len(self.schema.PoseEstimationModel & f'model_num = {self.parameters["model_num"]}'):
                raise(ValueError(f'Could not find model {self.parameters["model_num"]}'))

    def find_datasets(self, subject_name = None, session_name = None):
        '''
        Searches for subjects and sessions 
        '''
        if self.parameters['mode'] ==  'train':
            # training only needs the label set; check that it exists
            pose_label_set = (self.schema.PoseEstimationLabelSet() & f'pose_label_set_num = {self.parameters["label_set"]}').fetch()
            return 
        if subject_name is None and session_name is None and self.parameters['mode'] == 'infer':
            raise(ValueError('Need to select a dataset to infer using a deeplabcut model.'))
        keys = []
        if subject_name is not None:
            if len(subject_name) > 1:
                raise ValueError(f'Please submit one subject at a time {subject_name}.')
            if not subject_name[0] == '':
                subject_name = subject_name[0]
        if session_name is not None:
            for s in session_name:
                if not s == '':
                    keys.append(dict(subject_name = subject_name,
                                     session_name = s))
        else:
            raise(NotImplementedError('Specifying no session is not yet implemented'))
        datasets = []
        for k in keys:
            datasets += (self.schema.DatasetVideo()& k).fetch(as_dict = True)
        # keep only the identifying fields and one dataset per session
        datasets = [dict(subject_name = d['subject_name'],
                         session_name = d['session_name'],
                         dataset_name = d['dataset_name']) for d in datasets]
        datasets = list({v['session_name']:v for v in datasets}.values())
        return datasets

    def _compute(self):
        '''Run the job: train a DLC model or infer poses, then store results.'''
        import deeplabcut
        if self.parameters['mode'] ==  'train':
            # check that the label set exists...
            print(self.parameters)
            print(f'pose_label_set_num = {self.parameters["label_set"]}')
            pose_label_set = self.schema.PoseEstimationLabelSet() & f'pose_label_set_num = {self.parameters["label_set"]}'
            cfgfile = create_project(self.parameters,self.parameters['model_num'],schema=self.schema)
            #print("Checking the labels.")
            #deeplabcut.check_labels(cfgfile)
            print("Generating the training dataset")
            deeplabcut.create_training_dataset(cfgfile)
            print("Training network")
            deeplabcut.train_network(cfgfile, maxiters = self.parameters['iteractions'])
            # once training completes, create a zip with the model and upload.
            self.schema.PoseEstimationModel().insert_model(self.parameters['model_num'], 
                    model_folder=Path(cfgfile).parent,
                    pose_label_set_num = self.parameters["label_set"],
                    algorithm_name = self.parameters['algorithm'],
                    parameters = self.model_parameters,
                    training_datetime=datetime.now(),
                    container_name = self.container,
                    code_link = self.url)
            # Save to PoseEstimationModel()
        elif self.parameters['mode'] == 'infer':
            # download the model if needed
            cfgfile = (self.schema.PoseEstimationModel() & f'model_num = {self.parameters["model_num"]}').get_model()
            datasets = (self.schema.File() & (self.schema.DatasetVideo.File() & self.dataset_key)).fetch(as_dict = True)
            if not len(datasets):
                raise(ValueError(f"Could not find {self.dataset_key}"))
            if len(datasets) > 1:
                # select the video to analyse
                datasets = (self.schema.File() & (self.schema.DatasetVideo.File() & dict(
                    self.dataset_key,
                    video_name = self.parameters['video_name']))).fetch(as_dict = True)
            localfiles = self.get_files(datasets)
            resfile = deeplabcut.analyze_videos(cfgfile,[str(f) for f in localfiles], videotype='.avi')

            # Save the results to PoseEstimation()
            if len(localfiles)>1:
                print(f'Not sure how to insert multiple files.. Check the inputs {localfiles}.')
            resfile = Path(str(localfiles[0].with_suffix(''))+resfile).with_suffix('.h5') # assuming there is only one file
            bodyparts,xyl = read_dlc_file(resfile)  
            # one PoseEstimation row per tracked bodypart
            toinsert = []
            for i,b in enumerate(bodyparts):
                toinsert.append(dict(self.dataset_key,
                                     video_name = self.parameters['video_name'],
                                     model_num = self.parameters["model_num"],
                                     label_name = b,
                                     x = xyl[:,i,0],
                                     y = xyl[:,i,1],
                                     likelihood = xyl[:,i,2]))
            self.schema.PoseEstimation().insert(toinsert)

__init__(job_id, project=None, allow_s3=None, **kwargs)

Run deeplabcut on video or train a model

Source code in labdata/compute/pose.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def __init__(self,job_id, project = None, allow_s3 = None, **kwargs):
    '''
    Run deeplabcut on video or train a model
    '''
    super(DeeplabcutCompute,self).__init__(job_id, project = project, allow_s3 = allow_s3)
    self.file_filters = ['.avi','.mov','.mp4','.zarr'] # allowed file extensions..
    # default parameters
    self.parameters = dict(algorithm = 'deeplabcut',
                           mode = None, # select 'train' or 'infer'
                           model_num = None,
                           label_set = None,
                           video_name = None,
                           net_type = 'resnet_50',
                           batch_size = 8,
                           iteractions = 100000)
    self._init_job()
    if not self.job_id is None:
        self.add_parameter_key()

find_datasets(subject_name=None, session_name=None)

Searches for subjects and sessions

Source code in labdata/compute/pose.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def find_datasets(self, subject_name = None, session_name = None):
    '''
    Searches for subjects and sessions 
    '''
    mode = self.parameters['mode']
    if mode ==  'train':
        # training only needs the label set; make sure it exists
        pose_label_set = (self.schema.PoseEstimationLabelSet() & f'pose_label_set_num = {self.parameters["label_set"]}').fetch()
        return 
    if subject_name is None and session_name is None and mode == 'infer':
        raise(ValueError('Need to select a dataset to infer using a deeplabcut model.'))
    if subject_name is not None:
        # only one subject may be submitted per job
        if len(subject_name) > 1:
            raise ValueError(f'Please submit one subject at a time {subject_name}.')
        if subject_name[0] != '':
            subject_name = subject_name[0]
    if session_name is None:
        raise(NotImplementedError('Specifying no session is not yet implemented'))
    keys = [dict(subject_name = subject_name, session_name = ses)
            for ses in session_name if ses != '']
    datasets = []
    for key in keys:
        datasets.extend((self.schema.DatasetVideo() & key).fetch(as_dict = True))
    # keep only the identifying fields, one dataset per session
    trimmed = [dict(subject_name = rec['subject_name'],
                    session_name = rec['session_name'],
                    dataset_name = rec['dataset_name']) for rec in datasets]
    return list({rec['session_name']: rec for rec in trimmed}.values())

CaimanCompute

Bases: Suite2pCompute

Source code in labdata/compute/caiman.py
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
class CaimanCompute(Suite2pCompute):
    '''Cell segmentation with CaImAn (CNMF-E).

    Reuses the parameter-set bookkeeping from Suite2pCompute but provides its
    own CLI parsing and compute step. Only Miniscope (1-photon) datasets are
    processed at the moment; TwoPhoton datasets raise NotImplementedError.
    '''
    container = 'labdata-caiman'   # container image the task runs in
    cuda = False                   # CaImAn is run on CPU here
    name = 'caiman'
    url = 'http://github.com/flatironinstitute/CaImAn'

    def __init__(self, job_id, project = None, allow_s3 = None, **kwargs):
        '''
        This class runs Caiman on a dataset, which can be used for both 1p and 2p data. The ComputeTask will:

        1. **File Identification and Dataset Type Check**: Identify the files and determine the type of dataset (Miniscope or TwoPhoton).
        2. **File Copy to Scratch**: Copy only the necessary file(s) to a scratch folder for processing.
        3. **Caiman Execution**: Execute Caiman on the copied file or folder, using the specified parameters.
        4. **Cleanup and Result Integration**: Delete the memory-mapped files generated during processing and integrate the results into the CellSegmentation table.

        This class includes a handler for the CLI.
        '''
        # NOTE: deliberately skips Suite2pCompute.__init__ and runs the
        # BaseCompute initializer, so suite2p defaults are not applied.
        super(Suite2pCompute, self).__init__(job_id, project = project, allow_s3 = allow_s3)
        self.file_filters = ['.zarr.zip'] # this only runs on zarr.zip for the moment.
        # default parameters
        self.parameters = dict(algorithm_name = 'caiman')
        # only these keys are persisted in CellSegmentationParams
        self.parameter_keys = ['pw_rigid','p','gSig','gSiz','merge_thr','rf','stride','tsub',
                               'ssub','nb','min_corr','min_pnr','ssub_B','ring_size_factor',
                               'min_SNR', 'rval_thr', 'use_cnn', 'detrendWin','quantileMin','denoise_dff']

        self.parameter_set_num = None # identifier in CellSegmentationParams

        self._init_job()
        if self.job_id is not None:
            self.add_parameter_key()

    def _secondary_parse(self, arguments, parameter_number = None):
        '''
        Handles parsing the command line interface.

        If ``parameter_number`` is given, the parameter set is loaded from the
        CellSegmentationParams table; otherwise the parameters are parsed from
        ``arguments`` (argv-style list, first element is skipped).
        '''
        if parameter_number is not None:
            self.parameters = (self.schema.CellSegmentationParams()
                               & f'parameter_set_num = {parameter_number}').fetch(as_dict = True)
            if not len(self.parameters):
                # was raise(f'...'), which raises a TypeError because a str is
                # not an exception; raise a proper exception with the message.
                raise ValueError(f'Could not find parameter {parameter_number} in CellSegmentationParams.')
            self.parameters = self.parameters[0]
        else:
            import argparse
            parser = argparse.ArgumentParser(
                description = 'Segmentation of imaging datasets using CaImAn.',
                usage = 'caiman -a <SUBJECT> -s <SESSION> -- <PARAMETERS>')
            parser.add_argument('-i','--parameter_set',default = None, type=int, help = 'Parameter set number')
            parser.add_argument('-m','--pwrigid',
                                action='store_true', default=False,
                                help = "Piecewise-rigid registration")
            parser.add_argument('-p',
                                action='store', default=1, type = int,
                                help = "Order of the autoregressive system")
            parser.add_argument('-g','--gsig',
                                action='store', default=[6,6], type = int, nargs = 2,
                                help = "Expected halfwidth of the neurons in pixels")
            parser.add_argument('--nb',
                                action='store', default=0, type = int,
                                help = "number of background components (rank) if positive, set to 0 for CNMFE")
            parser.add_argument('-r','--rf',
                                action='store', default=40, type = int,
                                help = "half-size of the patches in pixels. e.g., if rf=40, patches are 80x80")
            parser.add_argument('--stride',
                                action='store', default=20, type = int,
                                help = "size of the overlap in pixels")
            parser.add_argument('-s','--ring-size',
                                action='store', default=1.4, type = float,
                                help = "radius of ring is gSiz*ring_size_factor")
            parser.add_argument('--merge-thr',
                                action='store', default=.7, type = float,
                                help = "merging threshold, max correlation allowed")
            parser.add_argument('--tsub',
                                action='store', default=4, type = int,
                                help = "downsampling factor in time for initialization, increase if you have memory problems")
            parser.add_argument('--ssub',
                                action='store', default=2, type = int,
                                help = "downsampling factor in space for initialization, increase if you have memory problems")
            parser.add_argument('--ssub_b',
                                action='store', default=2, type = int,
                                help = "additional downsampling factor for the background")
            parser.add_argument('--min-corr',
                                action='store', default=.8, type = float,
                                help = "min peak value from correlation image")
            parser.add_argument('--min-pnr',
                                action='store', default=10, type = float,
                                help = "min peak to noise ratio")

            parser.add_argument('--snr_thr',
                                action='store', default=7, type = float,
                                help = "[cell selection] signal to noise ratio threshold")
            parser.add_argument('--rval-thr',
                                action='store', default=0.85, type = float,
                                help = "[cell selection] spatial correlation threshold")
            parser.add_argument('--use-cnn',
                                action='store_true', default = False,
                                help = "[cell selection] use a CNN to help deciding good from bad units")
            parser.add_argument('--denoise-dff',
                                action='store_true', default = False,
                                help = "[df/f] compute dff on the denoised data")

            parser.add_argument('--quantile_min',
                                action='store', default=8, type = float,
                                help = "[df/f] Minimum quantile for df/f detrending")
            parser.add_argument('--detrend_win',
                                action='store', default=250, type = int,
                                help = "[df/f] Number of frames for detrending")
            parser.add_argument('--roi',
                                action='store', default=None, type = int, nargs = 4,
                                help = "ROI")

            args = parser.parse_args(arguments[1:])

            # gSiz is derived from gSig (full width = 2*halfwidth + 1)
            params = dict(pw_rigid = args.pwrigid,
                          p = int(args.p),
                          gSig = [int(a) for a in args.gsig],
                          gSiz = [int(a) for a in 2*np.array(args.gsig) + 1],
                          merge_thr = float(args.merge_thr),
                          rf = int(args.rf),
                          stride = int(args.stride),
                          tsub = int(args.tsub),
                          ssub = int(args.ssub),
                          nb = int(args.nb),
                          min_corr = float(args.min_corr),
                          min_pnr = float(args.min_pnr),
                          ssub_B = int(args.ssub_b),
                          ring_size_factor = float(args.ring_size),
                          min_SNR = float(args.snr_thr),
                          rval_thr = float(args.rval_thr),
                          use_cnn = bool(args.use_cnn),
                          detrendWin = int(args.detrend_win),
                          denoise_dff = bool(args.denoise_dff),
                          quantileMin = float(args.quantile_min),
                          roi = args.roi)
            self.parameters = params

    def _compute(self):
        '''Run CaImAn motion correction and CNMF-E and store the results.

        Exports the movie to temporary TIFFs, motion-corrects, memory-maps,
        fits CNMF-E, evaluates component quality, computes df/f and
        projections, then inserts everything into CellSegmentation.
        '''
        from ..stacks import export_to_tiff
        import string
        # random suffix so concurrent jobs do not collide in the scratch area
        rand = ''.join(np.random.choice([s for s in string.ascii_lowercase + string.digits],9))
        temporary_folder = Path(prefs['scratch_path'])/f'caiman_temporary_{rand}'
        dset = (self.schema.Miniscope() & self.dataset_key)
        # CNMF-E (1-photon) defaults; the stored parameter set is merged on top
        cnmfparams = {
                'motion_correct' : True,
                'method_init': 'corr_pnr',  # use this for 1 photon
                'K': None, # for 1p
                'nb': 0,             # number of background components (rank) if positive, set to 0 for CNMFE
                'nb_patch': 0,
                'low_rank_background': None,           # for 1p
                'update_background_components': True,  # sometimes setting to False improve the results
                'del_duplicates': True,                # whether to remove duplicates from initialization
                'normalize_init': False,               # just leave as is
                'center_psf': True,                    # True for 1p
                'only_init': True,    # set it to True to run CNMF-E
                'method_deconvolution': 'oasis'}       # could use 'cvxpy' alternatively

        if not len(dset):
            dset = (self.schema.TwoPhoton() & self.dataset_key)
            is_two_photon = True
            # Two-photon support (per-plane parameters, frame-rate lookup) is
            # not implemented: downstream code needs frame_rate and a
            # Miniscope key, so fail early with a clear error instead of a
            # NameError at the CNMFParams step.
            raise NotImplementedError('CaimanCompute does not support TwoPhoton datasets yet.')
        else:
            frame_rate = (self.schema.Miniscope() & self.dataset_key).fetch1('frame_rate')
            is_two_photon = False
        # Remember whether the raw files were already local so cleanup only
        # deletes files this job had to download.
        # NOTE(review): files_existed is presumably initialized by the base
        # class; it is only set here when the files were NOT local — confirm.
        if len((self.schema.File & dset).check_if_files_local()[0]) == 0:
            self.files_existed = False
        dat = dset.open()
        parameters = (self.schema.CellSegmentationParams & f'parameter_set_num = {self.parameter_set_num}').fetch1()

        # merge the stored parameter set on top of the CNMF-E defaults
        params = json.loads(parameters['parameters_dict'])
        for k in params.keys():
            cnmfparams[k] = params[k]

        import logging
        logger = logging.getLogger('caiman')
        # Set to logging.INFO if you want much output, potentially much more output
        logger.setLevel(logging.WARNING)
        handler = logging.StreamHandler()
        logger.addHandler(handler)

        import time
        from caiman.source_extraction.cnmf.params import CNMFParams
        parameters = CNMFParams(params_dict = dict({k:cnmfparams[k] for k in ['motion_correct','pw_rigid']},
                                                   fr = frame_rate))

        # keep all caiman scratch files inside the temporary folder
        os.environ['CAIMAN_DATA'] = f'{temporary_folder}'

        # .get: 'roi' is always present when parameters came from the CLI but
        # may be absent when they were fetched from CellSegmentationParams
        paths = export_to_tiff(dat,temporary_folder,
                               crop_region = self.parameters.get('roi'))
        pmotion = parameters.get_group('motion')

        n_cpus = DEFAULT_N_JOBS
        cluster = setup_cluster(n_cpus)

        tstart = time.time()
        from caiman.motion_correction import MotionCorrect
        mot_correct = MotionCorrect(paths, dview=cluster, **pmotion)
        mot_correct.motion_correct(save_movie=True)
        print(f'Done with motion correction in {(time.time() - tstart)/60.} min.')

        # the exported TIFFs are no longer needed once motion corrected
        [os.unlink(f) for f in paths]
        fname_mc = mot_correct.fname_tot_els if cnmfparams['pw_rigid'] else mot_correct.fname_tot_rig
        # border pixels to exclude, from the maximum registration shift
        if pmotion['pw_rigid']:
            bord_px = np.ceil(np.maximum(np.max(np.abs(mot_correct.x_shifts_els)),
                                         np.max(np.abs(mot_correct.y_shifts_els)))).astype(int)
        else:
            bord_px = np.ceil(np.max(np.abs(mot_correct.shifts_rig))).astype(int)

        bord_px = 0 if pmotion['border_nan'] == 'copy' else bord_px
        from caiman import save_memmap
        fname_new = save_memmap(fname_mc, base_name='memmap_', order='C',
                                border_to_0 = bord_px)
        [os.unlink(f) for f in fname_mc]

        from caiman import load_memmap
        Yr, dims, T = load_memmap(fname_new)
        images = Yr.T.reshape((T,) + dims, order='F')

        tstart = time.time()
        from caiman.source_extraction import cnmf
        parameters.change_params(cnmfparams)
        cnmfe_model = cnmf.CNMF(n_processes = n_cpus,
                                dview = cluster,
                                params = parameters)

        cnmfe_model.fit(images)

        print(f'Done with CNMF in {(time.time() - tstart)/60.} min.')

        # component quality filtering (SNR / spatial correlation / CNN)
        quality_params = {'min_SNR': params['min_SNR'],
                          'rval_thr': params['rval_thr'],
                          'use_cnn': params['use_cnn']}
        cnmfe_model.params.change_params(params_dict=quality_params)
        print(f"Evaluating components.")
        cnmfe_model.estimates.evaluate_components(images, cnmfe_model.params, dview=cluster)
        print(f"Computing df/f.")
        cnmfe_model.estimates.detrend_df_f(quantileMin = params['quantileMin'],
                                           frames_window = int(params['detrendWin']),
                                           flag_auto = False,
                                           use_residuals = not params['denoise_dff'])
        print('*****')
        print(f"Total number of components: {len(cnmfe_model.estimates.C)}")
        print(f"Number accepted: {len(cnmfe_model.estimates.idx_components)}")
        print(f"Number rejected: {len(cnmfe_model.estimates.idx_components_bad)}")
        cluster.terminate()
        from ..stacks import compute_projections
        mean_proj,std_proj,max_proj,corr_proj = compute_projections(images)
        print('Projections computed.')
        import caiman
        if not is_two_photon:
            # Miniscope data has a single imaging plane
            iplane = 0
            dkey = (self.schema.Miniscope & self.dataset_key).proj().fetch1()
        cell_seg = dict(dkey,
                        parameter_set_num = self.parameter_set_num,
                        algorithm_version = f'caiman {caiman.__version__}',
                        n_rois = cnmfe_model.estimates.A.shape[-1],
                        crop_region = self.parameters.get('roi'),
                        segmentation_datetime = datetime.now()) # if we need a file to store results, it goes here
        # spatial footprints: (pixels x cells) sparse -> dense (cells, x, y)
        roi_masks = np.array(cnmfe_model.estimates.A.todense()).reshape((*cnmfe_model.dims,-1))
        roi_masks = roi_masks.transpose(2,1,0)

        planekey = dict(dkey,
                        parameter_set_num  = self.parameter_set_num,
                        plane_num = iplane)
        planedict = dict(planekey,
                         plane_n_rois = len(roi_masks),
                         dims = [a for a in mean_proj.shape])
        roidict = []
        tracesdict = []
        rawtracesdict = []
        deconvtracesdict  = []
        selectiondict = []
        for icell,roi in enumerate(roi_masks):
            roi_pixels, roi_pixels_values = get_roi_pixels(roi)
            if len(roi_pixels)>5000:
                # implausibly large footprint, skip it
                print(f"[CaimanCompute] {icell} had an roi with {len(roi_pixels)} pixels. Skipping.")
                continue
            roidict.append(dict(planekey,
                                roi_num = icell,
                                roi_pixels = roi_pixels,
                                roi_pixels_values = roi_pixels_values))
            tracesdict.append(dict(planekey,
                                   roi_num = icell,
                                   dff = cnmfe_model.estimates.F_dff[icell].astype(np.float32)))
            rawtracesdict.append(dict(planekey,
                                      roi_num = icell,
                                      f_trace = cnmfe_model.estimates.C[icell].astype(np.float32)))
            deconvtracesdict.append(dict(planekey,
                                      roi_num = icell,
                                      deconv = cnmfe_model.estimates.S[icell].astype(np.float32)))
            # mark components caiman accepted as selected
            selection = 0
            if icell in cnmfe_model.estimates.idx_components:
                selection = 1
            selectiondict.append(dict(planekey,
                                  roi_num = icell,
                                  selection_method = 'auto',
                                  selection = selection))
        projdict = [dict(planekey,
                         proj_name = n,
                         proj_im = pi) for n,pi in zip(
                             ['mean','std','max','lcorr'],
                             [mean_proj,std_proj,max_proj,corr_proj])]

        self.schema.CellSegmentation.insert1(cell_seg, allow_direct_insert = True)
        self.schema.CellSegmentation.Plane.insert1(planedict,allow_direct_insert = True)
        if not params['pw_rigid']:
            motion = np.array(mot_correct.shifts_rig).astype(np.float32)
            self.schema.CellSegmentation.MotionCorrection.insert1(dict(planekey,
                                                           motion_block_size = 0,
                                                           displacement = motion))
        self.schema.CellSegmentation.Projection.insert(projdict,allow_direct_insert = True)
        self.schema.CellSegmentation.ROI.insert(roidict ,allow_direct_insert = True)
        # Insert traces in parallel to prevent timeout errors from mysql
        parallel_insert(self.schema.schema_project,'CellSegmentation.Traces',tracesdict, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)
        parallel_insert(self.schema.schema_project,'CellSegmentation.RawTraces',rawtracesdict, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)
        parallel_insert(self.schema.schema_project,'CellSegmentation.Deconvolved',deconvtracesdict, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)
        parallel_insert(self.schema.schema_project,'CellSegmentation.Selection',selectiondict, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)
        if not self.keep_intermediate:
            print(f'[{self.name} job] Removing the temporary folder.')
            import shutil
            shutil.rmtree(temporary_folder)
            if not self.files_existed:
                # delete raw files that were downloaded just for this job
                for f in (self.schema.File & dset).get():
                    os.unlink(f)
        else:
            print(f'[{self.name} job] Kept the temporary folder {temporary_folder}.')

__init__(job_id, project=None, allow_s3=None, **kwargs)

This class runs Caiman on a dataset, which can be used for both 1p and 2p data. The ComputeTask will:

  1. File Identification and Dataset Type Check: Identify the files and determine the type of dataset (Miniscope or TwoPhoton).
  2. File Copy to Scratch: Copy only the necessary file(s) to a scratch folder for processing.
  3. Caiman Execution: Execute Caiman on the copied file or folder, using the specified parameters.
  4. Cleanup and Result Integration: Delete the memory-mapped files generated during processing and integrate the results into the CellSegmentation table.

This class includes a handler for the CLI.

Source code in labdata/compute/caiman.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def __init__(self,job_id, project = None, allow_s3 = None, **kwargs):
    '''
    This class runs Caiman on a dataset, which can be used for both 1p and 2p data. The ComputeTask will:

    1. **File Identification and Dataset Type Check**: Identify the files and determine the type of dataset (Miniscope or TwoPhoton).
    2. **File Copy to Scratch**: Copy only the necessary file(s) to a scratch folder for processing.
    3. **Caiman Execution**: Execute Caiman on the copied file or folder, using the specified parameters.
    4. **Cleanup and Result Integration**: Delete the memory-mapped files generated during processing and integrate the results into the CellSegmentation table.

    This class includes a handler for the CLI.
    '''
    # NOTE: deliberately skips Suite2pCompute.__init__ so suite2p defaults
    # are not applied; runs the BaseCompute initializer directly instead.
    super(Suite2pCompute,self).__init__(job_id, project = project, allow_s3 = allow_s3) # takes BaseCompute init
    self.file_filters = ['.zarr.zip'] # this only runs on zarr.zip for the moment.
    # default parameters
    self.parameters = dict(algorithm_name = 'caiman')
    # will only store these in CellSegmentationParams
    self.parameter_keys = ['pw_rigid','p','gSig','gSiz','merge_thr','rf','stride','tsub',
                           'ssub','nb','min_corr','min_pnr','ssub_B','ring_size_factor',
                           'min_SNR', 'rval_thr', 'use_cnn', 'detrendWin','quantileMin','denoise_dff']

    self.parameter_set_num = None # identifier in CellSegmentationParams

    # _init_job/add_parameter_key are presumably provided by the base
    # classes; parameters must be set before _init_job runs — do not reorder.
    self._init_job()
    if not self.job_id is None:
        self.add_parameter_key()

Suite2pCompute

Bases: BaseCompute

Source code in labdata/compute/suite2p.py
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
class Suite2pCompute(BaseCompute):
    '''
    Cell segmentation of imaging datasets with Suite2p, optionally followed by
    FISSA neuropil decontamination to produce df/f traces.

    Results (planes, projections, ROIs, traces, deconvolved activity and cell
    selections) are inserted into the ``CellSegmentation`` tables of the
    project schema.
    '''
    container = 'labdata-suite2p'
    cuda = False
    name = 'suite2p'
    url = 'http://github.com/mouseland/suite2p'
    def __init__(self,job_id, project = None, allow_s3 = None, **kwargs):
        '''
        This class runs Suite2p and FISSA neuropil decontamination on a dataset, which can be used for 2p data. 
        '''
        super(Suite2pCompute,self).__init__(job_id,project = project, allow_s3 = allow_s3)
        self.file_filters = ['.zarr.zip','.sbx','.tiff'] #works with zarr, sbx and tiff
        # default parameters
        self.parameters = dict(algorithm_name = 'suite2p+fissa')
        # will only store these in CellSegmentationParams
        self.parameter_keys = ["use_builtin_classifier","use_fissa","tau","nonrigid",
                               "sparse_mode","connected","spatial_scale","diameter",
                               "threshold_scaling"]

        self.parameter_set_num = None # identifier in CellSegmentationParams

        self._init_job()
        if not self.job_id is None:
            self.add_parameter_key()

    def _get_parameter_number(self):
        '''
        Return ``(parameter_set_num, parameters)`` where ``parameter_set_num``
        is the existing CellSegmentationParams entry matching the current
        parameters, or the next free number if no entry matches.
        '''
        parameter_set_num = None
        parameters = pd.DataFrame(self.schema.CellSegmentationParams().fetch())
        # only the keys listed in parameter_keys participate in the comparison
        filtered_par = {k:self.parameters[k] for k in self.parameter_keys}

        for i,r in parameters.iterrows():
            # go through every stored parameter set and compare
            if filtered_par == json.loads(r.parameters_dict):
                parameter_set_num = r.parameter_set_num
        if parameter_set_num is None:
            if len(parameters) == 0:
                parameter_set_num = 1
            else:
                parameter_set_num = np.max(parameters.parameter_set_num.values) + 1

        return parameter_set_num,parameters

    def add_parameter_key(self):
        '''
        Insert the current parameter set into CellSegmentationParams (if new)
        and fail the job early when the dataset was already segmented with it.

        Raises
        ------
        ValueError
            If every plane of the dataset is already segmented with this
            parameter set.
        '''
        parameter_set_num, parameters = self._get_parameter_number()
        if not parameter_set_num in parameters.parameter_set_num.values:
            filtered_par = {k:self.parameters[k] for k in self.parameter_keys}  
            self.schema.CellSegmentationParams().insert1(dict(parameter_set_num = parameter_set_num,
                                                  algorithm_name = self.name,
                                                  parameters_dict = json.dumps(filtered_par),
                                                  code_link = self.url),
                                             skip_duplicates=True)
        self.parameter_set_num = parameter_set_num
        # this can be applied to the TwoPhoton or the Miniscope datasets
        if self.dataset_key is None:
            print('dataset_key was not set.')
            return
        if len(self.schema.TwoPhoton.Plane() & self.dataset_key):
            recordings = self.schema.TwoPhoton.Plane() & self.dataset_key
            segmentations = self.schema.CellSegmentation.Plane() & self.dataset_key & dict(parameter_set_num = self.parameter_set_num)
        elif len(self.schema.Miniscope() & self.dataset_key):
            recordings = self.schema.Miniscope() & self.dataset_key
            segmentations = self.schema.CellSegmentation.Plane() & self.dataset_key & dict(parameter_set_num = self.parameter_set_num)
        else:
            # FIX: previously fell through with 'recordings' unbound (NameError)
            # when the dataset matched neither TwoPhoton nor Miniscope.
            print(f'No TwoPhoton or Miniscope recordings found for {self.dataset_key}.')
            return
        if len(recordings) == len(segmentations):
            self.set_job_status(
                job_status = 'FAILED',
                job_waiting = 0,
                job_log = f'{self.dataset_key} was already segmented with parameters {self.parameter_set_num}.')    
            raise(ValueError(f'{self.dataset_key} was already segmented with parameters {self.parameter_set_num}.'))

    def _secondary_parse(self,arguments,parameter_number = None):
        '''
        Handles parsing the command line interface.

        If ``parameter_number`` is given, parameters are loaded from
        CellSegmentationParams; otherwise they are parsed from ``arguments``.
        '''
        if not parameter_number is None:
            self.parameters = ((self.schema.CellSegmentationParams() & f'parameter_set_num = {parameter_number}')).fetch(as_dict = True)
            if not len(self.parameters):
                # FIX: raising a plain string is a TypeError; raise a real exception
                raise ValueError(f'Could not find parameter {parameter_number} in CellSegmentationParams.')
            self.parameters = self.parameters[0]
        else:
            import argparse
            parser = argparse.ArgumentParser(
                description = 'Segmentation of imaging datasets using Suite2p and FISSA neuropil correction.',
                usage = 'suite2p -a <SUBJECT> -s <SESSION> -- <PARAMETERS>')
            parser.add_argument('-i','--parameter_set',default = None, type=int, help = 'Parameter set number')

            parser.add_argument('--tau',default = 1.5, type=float, help = 'Decay time constant of the calcium indicator, use 0.7 for GCaMP6f, 1.0 for GCaMP6m, 1.25-1.5 for GCaMP6s.')
            parser.add_argument('--no-use-classifier',
                                action='store_true', default = False,
                                help = "[cell selection] do not use a classifier to help deciding good from bad cells")
            parser.add_argument('--no-nonrigid',
                                action='store_true', default = False,
                                help = "[motion correction] perform rigid motion correction")
            parser.add_argument('--no-fissa-denoise',
                                action='store_true', default = False,
                                help = "Use FISSA to get df/f and correct for neuropil contamination")
            parser.add_argument('--spatial-scale',default = 0, type=int,
                                help = 'Spatial scale for the recordings 0: auto, 1: 6 pixels, 2: 12 pixels), 3: 24 pixels, 4: 48 pixels')
            parser.add_argument('--diameter',default = 6, type=int,
                                help = 'Diameter for cell detection (when not running in sparse mode; default is 6)')
            parser.add_argument('--threshold-scaling',default = 1.2, type=float,
                                help = 'Threshold for ROI detection (higher=less cells detected; default is 1.2).')
            # FIX: was "default = True, action = 'store_false'", which inverted the
            # flag relative to the other --no-* options (sparse_mode defaulted to
            # False and *passing* --no-sparse enabled it). Now consistent.
            parser.add_argument('--no-sparse',default = False,action = "store_true",
                                help = 'Flag to control "sparse_mode" cell detection.')       
            parser.add_argument('--detect-processes',default = False,action = "store_true",
                                help = 'Allow connected components.')       
            parser.add_argument('--roi',
                                action='store', default=None, type = int, nargs = 4,
                                help = "ROI")

            args = parser.parse_args(arguments[1:])

            params = dict(use_builtin_classifier = not args.no_use_classifier,
                          use_fissa = not args.no_fissa_denoise,
                          tau = args.tau,
                          sparse_mode = not args.no_sparse,
                          connected = args.detect_processes,
                          threshold_scaling = args.threshold_scaling,
                          spatial_scale = args.spatial_scale,
                          diameter = args.diameter,
                          nonrigid = not args.no_nonrigid,
                          roi = args.roi)
            self.parameters = params

    def find_datasets(self, subject_name = None, session_name = None):
        '''
        Searches for subjects and sessions in TwoPhoton
        '''
        if subject_name is None and session_name is None:
            print("\n\nPlease specify a 'subject_name' and a 'session_name' to perform segmentation with Suite2p.\n\n")
        keys = []
        if not subject_name is None:
            if len(subject_name) > 1:
                raise ValueError(f'Please submit one subject at a time {subject_name}.')
            if not subject_name[0] == '':
                subject_name = subject_name[0]
        if not session_name is None:
            for s in session_name:
                if not s == '':
                    keys.append(dict(subject_name = subject_name,
                                     session_name = s))
        else:
            # find all sessions that can be segmented (not yet processed with
            # the current parameter set)
            parameter_set_num, parameters = self._get_parameter_number()
            sessions = np.unique(((
                (self.schema.Miniscope() & f'subject_name = "{subject_name}"') -
                (self.schema.CellSegmentation() & f'parameter_set_num = {parameter_set_num}'))).fetch('session_name'))
            for ses in sessions:
                keys.append(dict(subject_name = subject_name,
                                 session_name = ses))
        datasets = []
        for k in keys:
            datasets += (self.schema.Miniscope()& k).proj('subject_name','session_name','dataset_name').fetch(
                as_dict = True)
        # fall back to TwoPhoton when no Miniscope datasets match
        if not len(datasets):
            for k in keys:
                datasets += (self.schema.TwoPhoton()& k).proj('subject_name','session_name','dataset_name').fetch(
                    as_dict = True)
        return datasets

    def _compute(self):
        '''
        Run Suite2p (and optionally FISSA) on the dataset referenced by
        ``self.dataset_key`` and insert the results into CellSegmentation.
        '''
        from numba import set_num_threads # control the number of threads used by suite2p 
        set_num_threads(DEFAULT_N_JOBS)
        from threadpoolctl import threadpool_limits
        import suite2p
        version = f'suite2p{suite2p.version}'
        if self.parameters['use_fissa']:
            import fissa
            version+= f';fissa{fissa.__version__}'

        seskeys = (self.schema.TwoPhoton() & self.dataset_key).proj().fetch(as_dict = True)
        # NOTE(review): self.files_existed is only assigned here when no local
        # files are found; presumably BaseCompute sets the default — confirm.
        if len((self.schema.File & (self.schema.TwoPhoton & seskeys)).check_if_files_local()[0]) == 0:
            self.files_existed = False
        if len(seskeys):
            rec_filesorig = (self.schema.File & (self.schema.TwoPhoton & seskeys)).get()
            ses_par = (self.schema.TwoPhoton() & seskeys).fetch(as_dict = True)
        else:
            # FIX: NotImplemented is a singleton, not an exception; raising it
            # is a TypeError. Use NotImplementedError instead.
            raise NotImplementedError("Only two-photon segmentation is implemented using this compute.")

        nchannels = ses_par[0]['n_channels']
        nplanes = ses_par[0]['n_planes']
        fs = ses_par[0]['frame_rate']

        # unique scratch folder for this run
        import string
        rand = ''.join(np.random.choice([s for s in string.ascii_lowercase + string.digits],9))
        savedir = Path(prefs['scratch_path'])/f'suite2p_temporary_{rand}'
        savedir.mkdir(exist_ok=True,parents=True)

        ops = suite2p.default_ops()
        # overlay user parameters onto the suite2p defaults
        for p in self.parameters.keys():
            if p in ops.keys():
                ops[p] = self.parameters[p]

        # handle different file types
        sbx_files = [f for f in rec_filesorig if str(f).endswith('.sbx')]
        zarr_files = [f for f in rec_filesorig if str(f).endswith('.zarr.zip')]
        if len(zarr_files):
            print('Converting zarr file to tiff.')
            rec_files = [savedir/'rawdata.tif']
            from tifffile import TiffWriter
            with TiffWriter(str(rec_files[0]), bigtiff=True, append=True) as tif:
                from tqdm.auto import tqdm 
                for ifile,zarrfile in enumerate(zarr_files):
                    fd = open_zarr(zarrfile)
                    for o,f in tqdm(chunk_indices(len(fd),512),
                                    desc = f"Concatenating raw data [{ifile}]"):
                        tif.save(np.array(fd[o:f]).reshape((-1,*fd.shape[-2:])))
            ops['input_format'] = 'tif'
        elif len(sbx_files):
            print('File format is SBX.')
            ops['input_format'] = 'sbx'
            rec_files = sbx_files
        else:
            raise(ValueError(f'No valid file format found in {rec_filesorig}.'))

        ops['fs'] = fs
        ops['nplanes'] = nplanes
        ops['nchannels'] = nchannels
        ops['fast_disk'] = str(savedir)
        ops['save_folder'] = str(savedir/'suite2p')
        ops['combined'] = False # do not combine planes
        db = {
            'data_path': [s.parent for s in rec_files],
        }
        for o in ops.keys():
            print(f'    {o} : {ops[o]}')

        with threadpool_limits(limits=DEFAULT_N_JOBS, user_api='blas'):
            suite2p.run_s2p(ops=ops, db = db)

        from labdata.stacks import compute_projections
        # one folder per imaged plane, skipping suite2p's 'combined' output
        planefolders = natsorted([a.parent for a in savedir.rglob('*F.npy') 
                                  if not 'combined' in str(a)])

        # frame offsets used to split concatenated traces back into sessions
        offsets = np.cumsum([0,*[int(a) for a in (self.schema.TwoPhoton() & seskeys).fetch('n_frames')]])
        cellsegdicts = []
        planedicts = []
        projdicts = [] 
        roidicts = []
        tracesdicts = []
        rawtracesdicts = []
        deconvdicts = []
        selectiondicts = []
        nframes_projection = 4000
        cellcount = [0 for s in seskeys]
        for iplane,planefolder in enumerate(planefolders):
            F = np.load(planefolder/'F.npy')
            Fneu = np.load(planefolder/'Fneu.npy')
            spks = np.load(planefolder/'spks.npy')
            stats = np.load(planefolder/'stat.npy',allow_pickle=True)
            ops = np.load(planefolder/'ops.npy',allow_pickle=True).item()
            iscell = np.load(planefolder/'iscell.npy', allow_pickle=True)
            # load the motion-corrected binary file (int16 frames)
            binaryfile = planefolder/'data.bin'
            nrows = ops['Ly']
            ncols = ops['Lx']
            nframes = int(binaryfile.stat().st_size/nrows/ncols/2)
            binary = np.memmap(binaryfile,shape = (nframes, nrows, ncols),
                               dtype = 'int16',order = 'C')
            dims = binary.shape[1:]
            rois = []
            for icell,s in enumerate(stats):
                # flatten (y,x) pixel coordinates into linear indices
                ii = np.ravel_multi_index([s['ypix'],s['xpix']],dims)
                rois.append(dict(roi_num = icell,
                                 roi_pixels = ii,
                                 roi_pixels_values = s['lam']))
            # extract and denoise with FISSA!
            if self.parameters['use_fissa']:
                # this is a work around. Save the binary to a tiff file and extract with low-memory mode
                if str(rec_files[0]).endswith('rawdata.tif'):
                    print('Deleting the previously generated tiff file.')
                    os.unlink(rec_files[0])
                correctedfname = savedir/'correcteddata.tif'
                from tifffile import TiffWriter
                with TiffWriter(str(correctedfname), bigtiff=True, append=True) as tif:
                    from tqdm.auto import tqdm 
                    for o,f in tqdm(chunk_indices(len(binary),512),
                                        desc = f"Concatenating motion corrected data:"):
                            tif.save(np.array(binary[o:f]).reshape((-1,*binary.shape[-2:])))
                with threadpool_limits(limits=DEFAULT_N_JOBS, user_api='blas'):
                    dff,spks = extract_dff_fissa(stats,correctedfname,dims,
                                                 batch_size=ops['batch_size'],
                                                 tau=ops['tau'],
                                                 fs=ops['fs'])
            else:
                print('Skipping df/f, there will be no "Traces".')
                dff = None
            iplane = int(iplane)
            # split traces by session and build the insertion dictionaries
            for isession, (key,on,off) in enumerate(zip(seskeys,offsets[:-1],offsets[1:])):
                mean_proj,std_proj,max_proj, corr_proj =  compute_projections(binary[on:on+nframes_projection])
                planekey = dict(key,
                            parameter_set_num = self.parameter_set_num,
                            plane_num = iplane)
                planedicts.append(dict(planekey,
                             plane_n_rois = len(rois),
                             dims = [a for a in dims]))
                cellcount[isession] += len(rois) # count cells in all planes
                projdicts.extend([dict(planekey, proj_name = n,
                                       proj_im = pi) for n,pi in zip(['mean','std','max','lcorr'],
                                                                     [mean_proj,std_proj,max_proj,corr_proj])])
                for icell,roi in enumerate(rois):
                    roidicts.append(dict(planekey,**roi))
                    if not dff is None:
                        tracesdicts.append(dict(planekey,
                                                roi_num = icell,
                                                dff = dff[icell].astype(np.float32)[on:off]))
                    rawtracesdicts.append(dict(planekey,
                                              roi_num = icell,
                                              f_trace = F[icell].astype(np.float32)[on:off],
                                              f_neuropil = Fneu[icell].astype(np.float32)[on:off]))
                    deconvdicts.append(dict(planekey,
                                            roi_num = icell,
                                            deconv = spks[icell].astype(np.float32)[on:off]))
                    selectiondicts.append(dict(planekey,
                                               roi_num = icell,
                                               selection_method = 'auto',
                                               selection = bool(iscell[icell,0]),
                                               likelihood = iscell[icell,1]))                  
        datenow = datetime.now()
        for key,ncells in zip(seskeys,cellcount):
            cellsegdicts.append(dict(key,parameter_set_num = self.parameter_set_num,
                                     n_rois = ncells,
                                     crop_region = None,
                                     algorithm_version = version,
                                     segmentation_datetime = datenow))
        self.schema.CellSegmentation.insert(cellsegdicts, allow_direct_insert = True)
        self.schema.CellSegmentation.Plane.insert(planedicts,allow_direct_insert = True)
        self.schema.CellSegmentation.Projection.insert(projdicts,allow_direct_insert = True)
        self.schema.CellSegmentation.ROI.insert(roidicts ,allow_direct_insert = True)
        # Insert traces in parallel to prevent timeout errors from mysql

        parallel_insert(self.schema.schema_project,'CellSegmentation.Traces',tracesdicts, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)
        parallel_insert(self.schema.schema_project,'CellSegmentation.RawTraces',rawtracesdicts, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)
        parallel_insert(self.schema.schema_project,'CellSegmentation.Deconvolved',deconvdicts, n_jobs = DEFAULT_N_JOBS,
                        skip_duplicates = True, ignore_extra_fields = True)                        
        self.schema.CellSegmentation.Selection.insert(selectiondicts,allow_direct_insert = True)
        if not self.keep_intermediate:
            print(f'[{self.name} job] Removing the temporary folder {savedir}.')
            import shutil
            shutil.rmtree(savedir)
            if not self.files_existed:
                for f in rec_filesorig:
                    os.unlink(f)
        else:
            # FIX: referenced undefined name 'temporary_folder' (NameError)
            print(f'[{self.name} job] Kept the temporary folder {savedir}.')

__init__(job_id, project=None, allow_s3=None, **kwargs)

This class runs Suite2p and FISSA neuropil decontamination on a dataset, which can be used for 2p data.

Source code in labdata/compute/suite2p.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def __init__(self,job_id, project = None, allow_s3 = None, **kwargs):
    '''
    Set up a Suite2p + FISSA segmentation task for a two-photon dataset.
    '''
    super(Suite2pCompute,self).__init__(job_id,project = project, allow_s3 = allow_s3)
    # accepted raw-data formats
    self.file_filters = ['.zarr.zip','.sbx','.tiff'] #works with zarr, sbx and tiff
    # default parameter dictionary for this algorithm
    self.parameters = dict(algorithm_name = 'suite2p+fissa')
    # only these keys get persisted to CellSegmentationParams
    self.parameter_keys = ["use_builtin_classifier","use_fissa","tau","nonrigid",
                           "sparse_mode","connected","spatial_scale","diameter",
                           "threshold_scaling"]
    self.parameter_set_num = None # identifier in CellSegmentationParams
    self._init_job()
    if self.job_id is not None:
        self.add_parameter_key()

find_datasets(subject_name=None, session_name=None)

Searches for subjects and sessions in TwoPhoton

Source code in labdata/compute/suite2p.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def find_datasets(self, subject_name = None, session_name = None):
    '''
    Searches for subjects and sessions in TwoPhoton
    '''
    if subject_name is None and session_name is None:
        print("\n\nPlease specify a 'subject_name' and a 'session_name' to perform segmentation with Suite2p.\n\n")
    keys = []
    if subject_name is not None:
        # a single subject is expected, passed as a one-element list
        if len(subject_name) > 1:
            raise ValueError(f'Please submit one subject at a time {subject_name}.')
        if subject_name[0] != '':
            subject_name = subject_name[0]
    if session_name is not None:
        keys = [dict(subject_name = subject_name, session_name = s)
                for s in session_name if s != '']
    else:
        # look up every session that has not been segmented with the
        # current parameter set yet
        parameter_set_num, _ = self._get_parameter_number()
        remaining = ((self.schema.Miniscope() & f'subject_name = "{subject_name}"') -
                     (self.schema.CellSegmentation() & f'parameter_set_num = {parameter_set_num}'))
        keys = [dict(subject_name = subject_name, session_name = ses)
                for ses in np.unique(remaining.fetch('session_name'))]
    datasets = []
    for k in keys:
        datasets += (self.schema.Miniscope()& k).proj('subject_name','session_name','dataset_name').fetch(
            as_dict = True)
    # no miniscope data matched: fall back to the two-photon table
    if not len(datasets):
        for k in keys:
            datasets += (self.schema.TwoPhoton()& k).proj('subject_name','session_name','dataset_name').fetch(
                as_dict = True)
    return datasets

run_analysis(target, jobids, compute_obj, project=None)

Launches a set of analyses on a specific target

Source code in labdata/compute/utils.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def run_analysis(target, jobids, compute_obj, project = None):
    '''
    Launches a set of analyses on a specific target.

    Parameters
    ----------
    target : str
        Execution target: 'slurm', 'local', 'ec2-<size>' or the name of a
        remote defined in ``prefs['compute']['remotes']``.
    jobids : list
        ComputeTask job ids to launch.
    compute_obj : BaseCompute
        Compute object providing container name, cuda flag, etc.
    project : str, optional
        Project whose schema should be used.
    '''

    from .singularity import run_on_apptainer
    container_file = (Path(prefs['compute']['containers']['local'])/compute_obj.container).with_suffix('.sif')
    def _get_cmds(jobids,
                  project,
                  container_file = container_file,
                  bind = None,  # FIX: mutable default argument ([]) replaced
                  bind_from_prefs = True):
        # Build one shell command per job id, wrapped in apptainer when both
        # the container image and the apptainer binary are available.
        if bind is None:
            bind = []
        cmds = []
        from shutil import which
        for j in jobids:
            cc = f'labdata2 task {j}'
            if not project is None:
                cc += f' -p {project}'
            if container_file.exists() and which('apptainer'):
                cmds.append(run_on_apptainer(container_file,
                                             command = cc,
                                             cuda = compute_obj.cuda,
                                             bind = bind,
                                             bind_from_prefs = bind_from_prefs,
                                             dry_run = True))
            else:
                cmds.append(cc)
        return cmds
    task_host = prefs['hostname']
    if hasattr(compute_obj,'schema'): # reuse the schema
        schema = compute_obj.schema
    else:
        schema = load_project_schema(project)
    if target == 'slurm':  # run with slurm
        from .schedulers import slurm_exists, slurm_submit
        if slurm_exists():
            for jid,cmd in zip(jobids,_get_cmds(jobids,project)):
                begin = None
                files_archived = check_archived(jid,check_local = True, schema = schema)
                if files_archived:
                    # archived data needs time to be retrieved before the job runs
                    begin = "now+5hour"
                    print('Delaying job for 5 hours (retrieve archive). ')
                    schema.ComputeTask.update1(dict(job_id = jid,task_status = 'WAITING (ARCHIVE)'))

                if container_file.exists():

                    cmd += ' | ' +  run_on_apptainer(container_file,
                                                     command = f'labdata2 logpipe {jid}',
                                                     dry_run = True)
                else:
                    cmd += f' | labdata2 logpipe {jid}'
                slurmjob = slurm_submit(compute_obj.name,
                                        cmd,
                                        begin = begin,
                                        ntasks = 1,
                                        ncpuspertask = DEFAULT_N_JOBS, # change later to be called by the job.
                                        gpus = 1 if compute_obj.cuda else None,
                                        project = project)
                print(f'Submitted {tcolor["g"](compute_obj.name)} {tcolor["y"](jid)} to slurm [{tcolor["y"](slurmjob)}]')
                schema.ComputeTask.update1(dict(job_id = jid,
                                         task_host = task_host + f'@{slurmjob}',
                                         task_target = target))
        else:
            # FIX: 'cmds' was referenced here without ever being defined,
            # raising NameError instead of printing the unsubmitted commands.
            cmds = _get_cmds(jobids,project)
            print(f'{tcolor["r"]("Could not find SLURM: did not submit compute tasks:")}')
            print('\t\n'.join(cmds))
    elif target == 'local':  # run locally without scheduler
        for job_id in jobids:
            task = handle_compute(job_id, project = project)
            task.compute()

    elif 'ec2' in target:   # launch dedicated instance on AWS
        # TODO: delayed begin not used here.
        from .ec2 import ec2_cmd_for_launch,ec2_create_instance,ec2_connect
        session,ec2 = ec2_connect()
        # FIX: a stray positional 'compute_obj.cuda' was passed here, which
        # collided with the container_file keyword (TypeError at runtime).
        for jid, cmd in zip(jobids,_get_cmds(
                jobids,
                project,
                container_file = Path('idontexist'))):
            cmd = ec2_cmd_for_launch(compute_obj.container,
                                     cmd,
                                     singularity_cuda = compute_obj.cuda,
                                     append_log = jid)
            # check if the target contains the words small or large

            instance_type = target.replace('ec2-','')
            if instance_type in compute_obj.ec2.keys():
                instance_opts = compute_obj.ec2[instance_type]
            else:
                # using small instance
                instance_opts = compute_obj.ec2['small']
            ins = ec2_create_instance(ec2, user_data = cmd,
                                      **instance_opts)
            print(f'Submitted job to {tcolor["r"](ins["id"])} on an ec2 {instance_opts}')
            task_host = ins["id"]
            schema.ComputeTask.update1(dict(job_id = jid,
                                     task_host = task_host,
                                     task_target = target))
    else:
        # check if there are remote services to launch
        if 'remotes' in prefs['compute'].keys():
            names = prefs['compute']['remotes'].keys()
            targetname = str(target)
            if target in names:
                target = prefs['compute']['remotes'][target]
            else:
                raise(ValueError(f'Could not find target [{target}]'))
            from .schedulers import ssh_connect,slurm_schedule_remote
            container_file = f"$LABDATA_PATH/containers/{compute_obj.container}.sif"
            with ssh_connect(target['address'],target['user'],target['permission_key']) as conn:
                for j in jobids:
                    begin = None
                    files_archived = check_archived(j,schema = schema)
                    if files_archived:
                        begin = "now+5hour"
                        print('Delaying job for 5 hours (retrieve archive). ')
                        schema.ComputeTask.update1(dict(job_id = j,task_status = 'WAITING (ARCHIVE)'))
                    # needs to have LABDATA_PATH defined in the remote
                    cmd = run_on_apptainer(container_file,
                                           command = f'labdata2 task {j}',
                                           cuda = compute_obj.cuda,
                                           dry_run = True)
                    # generate slurm cmd and launch, piping the log back
                    cmd += ' | ' + run_on_apptainer(container_file,
                                                  command = f'labdata2 logpipe {j}',
                                                  dry_run = True)
                    opts = dict()
                    nt = str(targetname)
                    if compute_obj.name in target['analysis_options']:
                        opts = target['analysis_options'][compute_obj.name]
                        nt = f'{targetname}@{opts["queue"]}'
                    if 'pre' in opts.keys(): # this needs to be a list of things to add to the list of pre_cmds
                        target['pre_cmds'] += opts['pre']
                    slurmjob = slurm_schedule_remote(cmd,  
                                                     conn = conn,
                                                     begin = begin,
                                                     jobname = compute_obj.name+f'_{j}',
                                                     pre_cmds = target['pre_cmds'],
                                                     #remote_dir = '$LABDATA_PATH/remote_jobs',
                                                     container_path = container_file,
                                                     project = project,
                                                     database_user = prefs['database']['database.user'],
                                                     database_password = decrypt_string(prefs['database']['database.password'].replace('encrypted:','')) if 'encrypted:' in prefs['database']['database.password'] else prefs['database']['database.password'],
                                                     #key_dir = '$LABDATA_PATH',
                                                     **opts)
                    if not slurmjob is None:
                        print(f'Submitted {tcolor["r"](compute_obj.name)} job {tcolor["y"](j)} to {tcolor["y"](nt)}[{tcolor["y"](slurmjob)}]')
                    schema.ComputeTask.update1(dict(job_id = j,
                                             task_target = nt))

handle_compute(job_id, project=None)

Source code in labdata/compute/utils.py
20
21
22
23
24
25
26
27
28
29
def handle_compute(job_id,project = None):
    '''
    Look up a ComputeTask by id and return the matching compute object.

    Parameters
    ----------
    job_id : int
        Identifier of the ComputeTask to load.
    project : str, optional
        Project whose schema should be used.

    Raises
    ------
    ValueError
        If no task with the given id exists.
    '''
    schema = load_project_schema(project)
    jobinfo = pd.DataFrame((schema.ComputeTask() & dict(job_id = job_id)).fetch())
    if not len(jobinfo):
        # FIX: previously only printed and then crashed on jobinfo.iloc[0]
        # with an IndexError; raise a clear error instead.
        raise ValueError(f'No task with id: {job_id}')
    jobinfo = jobinfo.iloc[0]
    if jobinfo.task_waiting == 0:
        print(f'Task {job_id} is running on {jobinfo.task_host}')
    obj = load_analysis_object(jobinfo.task_name)(jobinfo.job_id,project)
    return obj